mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix deterministic issue when getting pipeline dtype and device (#10696)
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
@@ -1610,7 +1610,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
expected_modules.add(name)
|
||||
optional_parameters.remove(name)
|
||||
|
||||
return expected_modules, optional_parameters
|
||||
return sorted(expected_modules), sorted(optional_parameters)
|
||||
|
||||
@classmethod
|
||||
def _get_signature_types(cls):
|
||||
@@ -1652,10 +1652,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
|
||||
}
|
||||
|
||||
if set(components.keys()) != expected_modules:
|
||||
actual = sorted(set(components.keys()))
|
||||
expected = sorted(expected_modules)
|
||||
if actual != expected:
|
||||
raise ValueError(
|
||||
f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
|
||||
f" {expected_modules} to be defined, but {components.keys()} are defined."
|
||||
f" {expected} to be defined, but {actual} are defined."
|
||||
)
|
||||
|
||||
return components
|
||||
|
||||
@@ -19,7 +19,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
from diffusers.utils.testing_utils import require_torch_gpu, torch_device
|
||||
|
||||
|
||||
class IsSafetensorsCompatibleTests(unittest.TestCase):
|
||||
@@ -826,3 +826,104 @@ class ProgressBarTests(unittest.TestCase):
|
||||
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
|
||||
_ = pipe(**inputs)
|
||||
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
class PipelineDeviceAndDtypeStabilityTests(unittest.TestCase):
|
||||
expected_pipe_device = torch.device("cuda:0")
|
||||
expected_pipe_dtype = torch.float64
|
||||
|
||||
def get_dummy_components_image_generation(self):
|
||||
cross_attention_dim = 8
|
||||
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(4, 8),
|
||||
layers_per_block=1,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=2,
|
||||
)
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[4, 8],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
norm_num_groups=2,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=cross_attention_dim,
|
||||
intermediate_size=16,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=2,
|
||||
num_hidden_layers=2,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
"image_encoder": None,
|
||||
}
|
||||
return components
|
||||
|
||||
def test_deterministic_device(self):
|
||||
components = self.get_dummy_components_image_generation()
|
||||
|
||||
pipe = StableDiffusionPipeline(**components)
|
||||
pipe.to(device=torch_device, dtype=torch.float32)
|
||||
|
||||
pipe.unet.to(device="cpu")
|
||||
pipe.vae.to(device="cuda")
|
||||
pipe.text_encoder.to(device="cuda:0")
|
||||
|
||||
pipe_device = pipe.device
|
||||
|
||||
self.assertEqual(
|
||||
self.expected_pipe_device,
|
||||
pipe_device,
|
||||
f"Wrong expected device. Expected {self.expected_pipe_device}. Got {pipe_device}.",
|
||||
)
|
||||
|
||||
def test_deterministic_dtype(self):
|
||||
components = self.get_dummy_components_image_generation()
|
||||
|
||||
pipe = StableDiffusionPipeline(**components)
|
||||
pipe.to(device=torch_device, dtype=torch.float32)
|
||||
|
||||
pipe.unet.to(dtype=torch.float16)
|
||||
pipe.vae.to(dtype=torch.float32)
|
||||
pipe.text_encoder.to(dtype=torch.float64)
|
||||
|
||||
pipe_dtype = pipe.dtype
|
||||
|
||||
self.assertEqual(
|
||||
self.expected_pipe_dtype,
|
||||
pipe_dtype,
|
||||
f"Wrong expected dtype. Expected {self.expected_pipe_dtype}. Got {pipe_dtype}.",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user