diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 4bae7b9fa6..f9d3402d06 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -301,7 +301,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
-            if torch.is_floating_point(timesteps):
+            if isinstance(timestep, float):
                 dtype = torch.float32 if is_mps else torch.float64
             else:
                 dtype = torch.int32 if is_mps else torch.int64
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 7637339a84..d1a3d4c55e 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -379,7 +379,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
-            if torch.is_floating_point(timesteps):
+            if isinstance(timestep, float):
                 dtype = torch.float32 if is_mps else torch.float64
             else:
                 dtype = torch.int32 if is_mps else torch.int64
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
index 9502f69953..5da43be2ad 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
@@ -117,8 +117,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
 
         Returns:
             `jnp.ndarray`: scaled input sample
         """
-        (step_index,) = jnp.where(scheduler_state.timesteps == timestep, size=1)
-        sigma = scheduler_state.sigmas[step_index]
+        (step_index,) = jnp.where(state.timesteps == timestep, size=1)
+        sigma = state.sigmas[step_index]
         sample = sample / ((sigma**2 + 1) ** 0.5)
         return sample
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
index efa4bdc6f3..0b3b69c2c2 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
@@ -15,7 +15,6 @@
 
 import gc
 import tempfile
-import time
 import unittest
 
 import numpy as np
@@ -694,24 +693,6 @@ class StableDiffusion2PipelineIntegrationTests(unittest.TestCase):
         assert test_callback_fn.has_been_called
         assert number_of_steps == 20
 
-    def test_stable_diffusion_low_cpu_mem_usage(self):
-        pipeline_id = "stabilityai/stable-diffusion-2-base"
-
-        start_time = time.time()
-        pipeline_low_cpu_mem_usage = StableDiffusionPipeline.from_pretrained(
-            pipeline_id, revision="fp16", torch_dtype=torch.float16
-        )
-        pipeline_low_cpu_mem_usage.to(torch_device)
-        low_cpu_mem_usage_time = time.time() - start_time
-
-        start_time = time.time()
-        _ = StableDiffusionPipeline.from_pretrained(
-            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, low_cpu_mem_usage=False
-        )
-        normal_load_time = time.time() - start_time
-
-        assert 2 * low_cpu_mem_usage_time < normal_load_time
-
     def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()
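
Reviewer note, not part of the patch: a minimal sketch of why the first two hunks are needed. The guarded code path only runs when the timestep is not a tensor, and `torch.is_floating_point` accepts only tensors, so the old check raised a `TypeError` whenever a scheduler passed a plain Python number; `isinstance(timestep, float)` handles that case. The `is_mps = False` stand-in and the example value below are illustrative:

    import torch

    timestep = 0.5  # a plain Python float, as some schedulers pass

    # Old check: torch.is_floating_point() requires a Tensor argument,
    # so it cannot classify a plain Python number.
    try:
        torch.is_floating_point(timestep)
    except TypeError as err:
        print(f"old check fails on non-tensor timesteps: {err}")

    # New check: works for plain Python numbers and picks the same dtypes
    # the patched code selects (float64 off MPS, float32 on MPS, since the
    # MPS backend has no float64 support).
    is_mps = False  # stands in for `sample.device.type == "mps"`
    if isinstance(timestep, float):
        dtype = torch.float32 if is_mps else torch.float64
    else:
        dtype = torch.int32 if is_mps else torch.int64

    timesteps = torch.tensor([timestep], dtype=dtype)
    print(timesteps.dtype)  # torch.float64

The third hunk is a plain rename: `scale_model_input` receives its scheduler state as the `state` parameter, so the old references to `scheduler_state` were unresolved names.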