mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
refactor: Refactored code by Merging isinstance calls (#7710)
* Merged isinstance calls to make the code simpler. * Corrected formatting errors using ruff. --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
@@ -460,7 +460,7 @@ class StableDiffusionUpscaleLDM3DPipeline(
|
||||
)
|
||||
|
||||
# verify batch size of prompt and image are same if image is a list or tensor or numpy array
|
||||
if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
|
||||
if isinstance(image, (list, np.ndarray, torch.Tensor)):
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
|
||||
@@ -685,7 +685,7 @@ class UNet2DConditionModel(
|
||||
positive_len = 768
|
||||
if isinstance(cross_attention_dim, int):
|
||||
positive_len = cross_attention_dim
|
||||
elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
|
||||
elif isinstance(cross_attention_dim, (list, tuple)):
|
||||
positive_len = cross_attention_dim[0]
|
||||
|
||||
feature_type = "text-only" if attention_type == "gated" else "text-image"
|
||||
|
||||
@@ -817,7 +817,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
positive_len = 768
|
||||
if isinstance(cross_attention_dim, int):
|
||||
positive_len = cross_attention_dim
|
||||
elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
|
||||
elif isinstance(cross_attention_dim, (list, tuple)):
|
||||
positive_len = cross_attention_dim[0]
|
||||
|
||||
feature_type = "text-only" if attention_type == "gated" else "text-image"
|
||||
|
||||
@@ -197,7 +197,7 @@ class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline):
|
||||
)
|
||||
|
||||
# verify batch size of prompt and image are same if image is a list or tensor or numpy array
|
||||
if isinstance(image, list) or isinstance(image, np.ndarray):
|
||||
if isinstance(image, (list, np.ndarray)):
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
|
||||
@@ -221,7 +221,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
||||
)
|
||||
|
||||
# verify batch size of prompt and image are same if image is a list or tensor
|
||||
if isinstance(image, list) or isinstance(image, torch.Tensor):
|
||||
if isinstance(image, (list, torch.Tensor)):
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
else:
|
||||
|
||||
@@ -468,7 +468,7 @@ class StableDiffusionUpscalePipeline(
|
||||
)
|
||||
|
||||
# verify batch size of prompt and image are same if image is a list or tensor or numpy array
|
||||
if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
|
||||
if isinstance(image, (list, np.ndarray, torch.Tensor)):
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
|
||||
@@ -185,7 +185,7 @@ def preprocess(image):
|
||||
def preprocess_mask(mask, batch_size: int = 1):
|
||||
if not isinstance(mask, torch.Tensor):
|
||||
# preprocess mask
|
||||
if isinstance(mask, PIL.Image.Image) or isinstance(mask, np.ndarray):
|
||||
if isinstance(mask, (PIL.Image.Image, np.ndarray)):
|
||||
mask = [mask]
|
||||
|
||||
if isinstance(mask, list):
|
||||
|
||||
@@ -347,11 +347,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
otherwise a tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
|
||||
if (
|
||||
isinstance(timestep, int)
|
||||
or isinstance(timestep, torch.IntTensor)
|
||||
or isinstance(timestep, torch.LongTensor)
|
||||
):
|
||||
if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
|
||||
raise ValueError(
|
||||
(
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
|
||||
@@ -310,11 +310,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
|
||||
if (
|
||||
isinstance(timestep, int)
|
||||
or isinstance(timestep, torch.IntTensor)
|
||||
or isinstance(timestep, torch.LongTensor)
|
||||
):
|
||||
if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
|
||||
raise ValueError(
|
||||
(
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
|
||||
@@ -375,11 +375,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
|
||||
if (
|
||||
isinstance(timestep, int)
|
||||
or isinstance(timestep, torch.IntTensor)
|
||||
or isinstance(timestep, torch.LongTensor)
|
||||
):
|
||||
if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
|
||||
raise ValueError(
|
||||
(
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
|
||||
@@ -530,11 +530,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
|
||||
if (
|
||||
isinstance(timestep, int)
|
||||
or isinstance(timestep, torch.IntTensor)
|
||||
or isinstance(timestep, torch.LongTensor)
|
||||
):
|
||||
if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
|
||||
raise ValueError(
|
||||
(
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
|
||||
Reference in New Issue
Block a user