mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

refactor: Refactored code by Merging isinstance calls (#7710)

* Merged isinstance calls to make the code simpler.

* Corrected formatting errors using ruff.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Sai-Suraj-27 committed 2024-05-16 10:03:19 +05:30 (committed by GitHub)
parent 0f111ab794
commit 2afea72d29
11 changed files with 11 additions and 27 deletions
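
Every hunk below applies the same pattern: a chain of `isinstance` calls joined with `or` is collapsed into a single call that takes a tuple of types, which Python treats as "instance of any of these". A minimal sketch of the pattern, using illustrative names that do not appear in the diff:

    value = [1, 2, 3]

    # before: one isinstance call per type, chained with `or`
    if isinstance(value, list) or isinstance(value, tuple) or isinstance(value, dict):
        print("container")

    # after: a single call with a tuple of types, equivalent and shorter
    if isinstance(value, (list, tuple, dict)):
        print("container")

The two forms are behaviorally identical, so none of the hunks changes logic; the second bullet in the commit message just records a ruff formatting pass over the touched files.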


@@ -460,7 +460,7 @@ class StableDiffusionUpscaleLDM3DPipeline(
)
# verify batch size of prompt and image are same if image is a list or tensor or numpy array
-if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
+if isinstance(image, (list, np.ndarray, torch.Tensor)):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):


@@ -685,7 +685,7 @@ class UNet2DConditionModel(
positive_len = 768
if isinstance(cross_attention_dim, int):
positive_len = cross_attention_dim
-elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+elif isinstance(cross_attention_dim, (list, tuple)):
positive_len = cross_attention_dim[0]
feature_type = "text-only" if attention_type == "gated" else "text-image"


@@ -817,7 +817,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
positive_len = 768
if isinstance(cross_attention_dim, int):
positive_len = cross_attention_dim
-elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+elif isinstance(cross_attention_dim, (list, tuple)):
positive_len = cross_attention_dim[0]
feature_type = "text-only" if attention_type == "gated" else "text-image"


@@ -197,7 +197,7 @@ class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline):
)
# verify batch size of prompt and image are same if image is a list or tensor or numpy array
-if isinstance(image, list) or isinstance(image, np.ndarray):
+if isinstance(image, (list, np.ndarray)):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):


@@ -221,7 +221,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
)
# verify batch size of prompt and image are same if image is a list or tensor
-if isinstance(image, list) or isinstance(image, torch.Tensor):
+if isinstance(image, (list, torch.Tensor)):
if isinstance(prompt, str):
batch_size = 1
else:


@@ -468,7 +468,7 @@ class StableDiffusionUpscalePipeline(
)
# verify batch size of prompt and image are same if image is a list or tensor or numpy array
-if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
+if isinstance(image, (list, np.ndarray, torch.Tensor)):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):


@@ -185,7 +185,7 @@ def preprocess(image):
def preprocess_mask(mask, batch_size: int = 1):
if not isinstance(mask, torch.Tensor):
# preprocess mask
-if isinstance(mask, PIL.Image.Image) or isinstance(mask, np.ndarray):
+if isinstance(mask, (PIL.Image.Image, np.ndarray)):
mask = [mask]
if isinstance(mask, list):


@@ -347,11 +347,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
otherwise a tuple is returned where the first element is the sample tensor.
"""
-if (
-    isinstance(timestep, int)
-    or isinstance(timestep, torch.IntTensor)
-    or isinstance(timestep, torch.LongTensor)
-):
+if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"


@@ -310,11 +310,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
returned, otherwise a tuple is returned where the first element is the sample tensor.
"""
-if (
-    isinstance(timestep, int)
-    or isinstance(timestep, torch.IntTensor)
-    or isinstance(timestep, torch.LongTensor)
-):
+if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"


@@ -375,11 +375,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
-if (
-    isinstance(timestep, int)
-    or isinstance(timestep, torch.IntTensor)
-    or isinstance(timestep, torch.LongTensor)
-):
+if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"


@@ -530,11 +530,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
returned, otherwise a tuple is returned where the first element is the sample tensor.
"""
-if (
-    isinstance(timestep, int)
-    or isinstance(timestep, torch.IntTensor)
-    or isinstance(timestep, torch.LongTensor)
-):
+if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"