Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
fix tests
@@ -301,7 +301,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
-            if torch.is_floating_point(timesteps):
+            if isinstance(timestep, float):
                 dtype = torch.float32 if is_mps else torch.float64
             else:
                 dtype = torch.int32 if is_mps else torch.int64
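Note on the hunk above: `torch.is_floating_point` only accepts a `torch.Tensor` and raises a `TypeError` when given a Python scalar, yet this branch runs precisely when the timestep is not a tensor, so the scalar check `isinstance(timestep, float)` is the correct one. A minimal standalone sketch of the fixed dtype selection (device hard-coded for illustration; the real code derives `is_mps` from `sample.device`):

import torch

timestep = 0.5  # a Python scalar, the case this branch handles
is_mps = False  # illustration assumes a non-MPS device; MPS has no float64

# The old check would fail here: torch.is_floating_point(timestep) raises
# TypeError because the argument must be a Tensor.
if isinstance(timestep, float):
    dtype = torch.float32 if is_mps else torch.float64
else:
    dtype = torch.int32 if is_mps else torch.int64

timesteps = torch.tensor([timestep], dtype=dtype)
print(timesteps)  # tensor([0.5000], dtype=torch.float64)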
@@ -379,7 +379,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
-            if torch.is_floating_point(timesteps):
+            if isinstance(timestep, float):
                 dtype = torch.float32 if is_mps else torch.float64
             else:
                 dtype = torch.int32 if is_mps else torch.int64
@@ -117,8 +117,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
         Returns:
             `jnp.ndarray`: scaled input sample
         """
-        (step_index,) = jnp.where(scheduler_state.timesteps == timestep, size=1)
-        sigma = scheduler_state.sigmas[step_index]
+        (step_index,) = jnp.where(state.timesteps == timestep, size=1)
+        sigma = state.sigmas[step_index]
         sample = sample / ((sigma**2 + 1) ** 0.5)
         return sample
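Note on the hunk above: the method's state parameter is named `state`, so the stale `scheduler_state` references raised a `NameError` as soon as the method was called. A minimal standalone sketch of the lookup-and-scale logic with made-up values (the real arrays live in the scheduler state object):

import jax.numpy as jnp

# Stand-in schedule; actual values come from the scheduler state.
timesteps = jnp.array([999, 749, 499, 249, 0])
sigmas = jnp.array([14.6, 7.1, 3.0, 1.1, 0.0])
sample = jnp.ones((1, 4, 64, 64))  # assumed latent shape, for illustration only

timestep = 499
# size=1 gives the output a static shape, keeping the lookup jit-compatible
(step_index,) = jnp.where(timesteps == timestep, size=1)
sigma = sigmas[step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)  # scales by 1/sqrt(sigma^2 + 1)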
@@ -15,7 +15,6 @@
 
 import gc
 import tempfile
-import time
 import unittest
 
 import numpy as np
@@ -694,24 +693,6 @@ class StableDiffusion2PipelineIntegrationTests(unittest.TestCase):
         assert test_callback_fn.has_been_called
         assert number_of_steps == 20
 
-    def test_stable_diffusion_low_cpu_mem_usage(self):
-        pipeline_id = "stabilityai/stable-diffusion-2-base"
-
-        start_time = time.time()
-        pipeline_low_cpu_mem_usage = StableDiffusionPipeline.from_pretrained(
-            pipeline_id, revision="fp16", torch_dtype=torch.float16
-        )
-        pipeline_low_cpu_mem_usage.to(torch_device)
-        low_cpu_mem_usage_time = time.time() - start_time
-
-        start_time = time.time()
-        _ = StableDiffusionPipeline.from_pretrained(
-            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, low_cpu_mem_usage=False
-        )
-        normal_load_time = time.time() - start_time
-
-        assert 2 * low_cpu_mem_usage_time < normal_load_time
-
     def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()
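Note on the hunk above: the deleted `test_stable_diffusion_low_cpu_mem_usage` asserted a wall-clock speedup (`2 * low_cpu_mem_usage_time < normal_load_time`), an inherently timing-sensitive check that depends on disk caching and machine load; removing it also leaves `import time` unused, hence the import hunk earlier in this diff. For reference, a minimal sketch of the two load paths the test compared (model id and keywords taken from the removed lines; requires network access, and `accelerate` for the low-memory path):

import torch
from diffusers import StableDiffusionPipeline

# Fast path: low_cpu_mem_usage defaults to True when accelerate is available,
# loading checkpoint weights directly instead of initializing weights twice.
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-base", revision="fp16", torch_dtype=torch.float16
)

# Slow path the removed test timed against:
pipe_slow = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-base",
    revision="fp16",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=False,
)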