1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix deterministic issue when getting pipeline dtype and device (#10696)

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
Dimitri Barbot
2025-03-15 03:20:58 +01:00
committed by GitHub
parent 6b9a3334db
commit be54a95b93
2 changed files with 107 additions and 4 deletions

View File

@@ -1610,7 +1610,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
expected_modules.add(name)
optional_parameters.remove(name)
return expected_modules, optional_parameters
return sorted(expected_modules), sorted(optional_parameters)
@classmethod
def _get_signature_types(cls):
@@ -1652,10 +1652,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
}
if set(components.keys()) != expected_modules:
actual = sorted(set(components.keys()))
expected = sorted(expected_modules)
if actual != expected:
raise ValueError(
f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
f" {expected_modules} to be defined, but {components.keys()} are defined."
f" {expected} to be defined, but {actual} are defined."
)
return components

View File

@@ -19,7 +19,7 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings
from diffusers.utils.testing_utils import torch_device
from diffusers.utils.testing_utils import require_torch_gpu, torch_device
class IsSafetensorsCompatibleTests(unittest.TestCase):
@@ -826,3 +826,104 @@ class ProgressBarTests(unittest.TestCase):
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
@require_torch_gpu
class PipelineDeviceAndDtypeStabilityTests(unittest.TestCase):
expected_pipe_device = torch.device("cuda:0")
expected_pipe_dtype = torch.float64
def get_dummy_components_image_generation(self):
cross_attention_dim = 8
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(4, 8),
layers_per_block=1,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=cross_attention_dim,
norm_num_groups=2,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[4, 8],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
norm_num_groups=2,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=cross_attention_dim,
intermediate_size=16,
layer_norm_eps=1e-05,
num_attention_heads=2,
num_hidden_layers=2,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
def test_deterministic_device(self):
components = self.get_dummy_components_image_generation()
pipe = StableDiffusionPipeline(**components)
pipe.to(device=torch_device, dtype=torch.float32)
pipe.unet.to(device="cpu")
pipe.vae.to(device="cuda")
pipe.text_encoder.to(device="cuda:0")
pipe_device = pipe.device
self.assertEqual(
self.expected_pipe_device,
pipe_device,
f"Wrong expected device. Expected {self.expected_pipe_device}. Got {pipe_device}.",
)
def test_deterministic_dtype(self):
components = self.get_dummy_components_image_generation()
pipe = StableDiffusionPipeline(**components)
pipe.to(device=torch_device, dtype=torch.float32)
pipe.unet.to(dtype=torch.float16)
pipe.vae.to(dtype=torch.float32)
pipe.text_encoder.to(dtype=torch.float64)
pipe_dtype = pipe.dtype
self.assertEqual(
self.expected_pipe_dtype,
pipe_dtype,
f"Wrong expected dtype. Expected {self.expected_pipe_dtype}. Got {pipe_dtype}.",
)