From e97a633b63942c8dd0fbc54eb8defbfa559b8161 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=2E=20Tolga=20Cang=C3=B6z?= <46008593+standardAI@users.noreply.github.com> Date: Mon, 18 Mar 2024 21:53:29 +0300 Subject: [PATCH] Update access of configuration attributes (#7343) Co-authored-by: Sayak Paul --- docs/source/ko/optimization/fp16.md | 2 +- docs/source/ko/using-diffusers/write_own_pipeline.md | 2 +- examples/community/stable_diffusion_ipex.py | 4 ++-- .../community/stable_diffusion_tensorrt_txt2img.py | 2 +- scripts/convert_if.py | 6 +++--- .../pipelines/deepfloyd_if/pipeline_if_img2img.py | 2 +- .../pipeline_if_img2img_superresolution.py | 2 +- .../pipelines/deepfloyd_if/pipeline_if_inpainting.py | 4 ++-- .../pipeline_if_inpainting_superresolution.py | 4 ++-- .../pipeline_onnx_stable_diffusion_upscale.py | 8 ++++---- .../pipeline_stable_diffusion_gligen.py | 4 ++-- .../pipeline_stable_diffusion_gligen_text_image.py | 2 +- .../schedulers/scheduling_consistency_models.py | 2 +- .../scheduling_dpmsolver_multistep_inverse.py | 2 +- src/diffusers/schedulers/scheduling_dpmsolver_sde.py | 2 +- .../schedulers/scheduling_euler_discrete.py | 2 +- src/diffusers/schedulers/scheduling_lms_discrete.py | 2 +- tests/models/unets/test_models_unet_2d_condition.py | 4 ++-- tests/others/test_config.py | 12 ++++++------ .../dance_diffusion/test_dance_diffusion.py | 4 ++-- 20 files changed, 36 insertions(+), 36 deletions(-) diff --git a/docs/source/ko/optimization/fp16.md b/docs/source/ko/optimization/fp16.md index 863f0fea2d..2e58421c35 100644 --- a/docs/source/ko/optimization/fp16.md +++ b/docs/source/ko/optimization/fp16.md @@ -355,7 +355,7 @@ unet_traced = torch.jit.load("unet_traced.pt") class TracedUNet(torch.nn.Module): def __init__(self): super().__init__() - self.in_channels = pipe.unet.in_channels + self.in_channels = pipe.unet.config.in_channels self.device = pipe.unet.device def forward(self, latent_model_input, t, encoder_hidden_states): diff --git a/docs/source/ko/using-diffusers/write_own_pipeline.md b/docs/source/ko/using-diffusers/write_own_pipeline.md index b56da2e5b4..772db1b4f4 100644 --- a/docs/source/ko/using-diffusers/write_own_pipeline.md +++ b/docs/source/ko/using-diffusers/write_own_pipeline.md @@ -210,7 +210,7 @@ Stable Diffusion 은 text-to-image *latent diffusion* 모델입니다. latent di ```py >>> latents = torch.randn( -... (batch_size, unet.in_channels, height // 8, width // 8), +... (batch_size, unet.config.in_channels, height // 8, width // 8), ... generator=generator, ... device=torch_device, ... ) diff --git a/examples/community/stable_diffusion_ipex.py b/examples/community/stable_diffusion_ipex.py index 8e71f79e9a..3b5ed09aa1 100644 --- a/examples/community/stable_diffusion_ipex.py +++ b/examples/community/stable_diffusion_ipex.py @@ -224,7 +224,7 @@ class StableDiffusionIPEXPipeline( # 5. Prepare latent variables latents = self.prepare_latents( batch_size * num_images_per_prompt, - self.unet.in_channels, + self.unet.config.in_channels, height, width, prompt_embeds.dtype, @@ -679,7 +679,7 @@ class StableDiffusionIPEXPipeline( timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/examples/community/stable_diffusion_tensorrt_txt2img.py b/examples/community/stable_diffusion_tensorrt_txt2img.py index 54661d66a2..1fcfafadb4 100755 --- a/examples/community/stable_diffusion_tensorrt_txt2img.py +++ b/examples/community/stable_diffusion_tensorrt_txt2img.py @@ -917,7 +917,7 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline): text_embeddings = self.__encode_prompt(prompt, negative_prompt) # Pre-initialize latents - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size, num_channels_latents, diff --git a/scripts/convert_if.py b/scripts/convert_if.py index c4588f4b25..85c739ca92 100644 --- a/scripts/convert_if.py +++ b/scripts/convert_if.py @@ -1195,9 +1195,9 @@ def superres_check_against_original(dump_path, unet_checkpoint_path): if_II_model = IFStageIII(device="cuda", dir_or_name=orig_path, model_kwargs={"precision": "fp32"}).model batch_size = 1 - channels = model.in_channels // 2 - height = model.sample_size - width = model.sample_size + channels = model.config.in_channels // 2 + height = model.config.sample_size + width = model.config.sample_size height = 1024 width = 1024 diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py index ccc7b1d151..99633ee215 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py @@ -613,7 +613,7 @@ class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin): for image_ in image: image_ = image_.convert("RGB") - image_ = resize(image_, self.unet.sample_size) + image_ = resize(image_, self.unet.config.sample_size) image_ = np.array(image_) image_ = image_.astype(np.float32) image_ = image_ / 127.5 - 1 diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py index b4ce5831a5..19c4f1d390 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py @@ -662,7 +662,7 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin): for image_ in image: image_ = image_.convert("RGB") - image_ = resize(image_, self.unet.sample_size) + image_ = resize(image_, self.unet.config.sample_size) image_ = np.array(image_) image_ = image_.astype(np.float32) image_ = image_ / 127.5 - 1 diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py index 180e5309c5..66a185b24f 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py @@ -654,7 +654,7 @@ class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin): for image_ in image: image_ = image_.convert("RGB") - image_ = resize(image_, self.unet.sample_size) + image_ = resize(image_, self.unet.config.sample_size) image_ = np.array(image_) image_ = image_.astype(np.float32) image_ = image_ / 127.5 - 1 @@ -701,7 +701,7 @@ class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin): for mask_image_ in mask_image: mask_image_ = mask_image_.convert("L") - mask_image_ = resize(mask_image_, self.unet.sample_size) + mask_image_ = resize(mask_image_, self.unet.config.sample_size) mask_image_ = np.array(mask_image_) mask_image_ = mask_image_[None, None, :] new_mask_image.append(mask_image_) diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py index b67907c1c1..5c01dfdc29 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py @@ -698,7 +698,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin): for image_ in image: image_ = image_.convert("RGB") - image_ = resize(image_, self.unet.sample_size) + image_ = resize(image_, self.unet.config.sample_size) image_ = np.array(image_) image_ = image_.astype(np.float32) image_ = image_ / 127.5 - 1 @@ -778,7 +778,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin): for mask_image_ in mask_image: mask_image_ = mask_image_.convert("L") - mask_image_ = resize(mask_image_, self.unet.sample_size) + mask_image_ = resize(mask_image_, self.unet.config.sample_size) mask_image_ = np.array(mask_image_) mask_image_ = mask_image_[None, None, :] new_mask_image.append(mask_image_) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py index 58d83de0d3..bee6ea7b11 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py @@ -469,7 +469,7 @@ class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline): latents = self.prepare_latents( batch_size * num_images_per_prompt, - self.num_latent_channels, + self.config.num_latent_channels, height, width, latents_dtype, @@ -498,12 +498,12 @@ class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline): # 7. Check that sizes of image and latents match num_channels_image = image.shape[1] - if self.num_latent_channels + num_channels_image != self.num_unet_input_channels: + if self.config.num_latent_channels + num_channels_image != self.config.num_unet_input_channels: raise ValueError( "Incorrect configuration settings! The config of `pipeline.unet` expects" - f" {self.num_unet_input_channels} but received `num_channels_latents`: {self.num_latent_channels} +" + f" {self.config.num_unet_input_channels} but received `num_channels_latents`: {self.config.num_latent_channels} +" f" `num_channels_image`: {num_channels_image} " - f" = {self.num_latent_channels + num_channels_image}. Please verify the config of" + f" = {self.config.num_latent_channels + num_channels_image}. Please verify the config of" " `pipeline.unet` or your `image` input." ) diff --git a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py index 9f0d1190fd..e0b40487ac 100644 --- a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +++ b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py @@ -680,7 +680,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin): timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, @@ -713,7 +713,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin): boxes = torch.zeros(max_objs, 4, device=device, dtype=self.text_encoder.dtype) boxes[:n_objs] = torch.tensor(gligen_boxes) text_embeddings = torch.zeros( - max_objs, self.unet.cross_attention_dim, device=device, dtype=self.text_encoder.dtype + max_objs, self.unet.config.cross_attention_dim, device=device, dtype=self.text_encoder.dtype ) text_embeddings[:n_objs] = _text_embeddings # Generate a mask for each object that is entity described by phrases diff --git a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py index 296ecae653..3570cdce99 100644 --- a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +++ b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py @@ -847,7 +847,7 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index cea51f5c0c..5a37886e22 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -233,7 +233,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): sigmas = self._convert_to_karras(ramp) timesteps = self.sigma_to_t(sigmas) - sigmas = np.concatenate([sigmas, [self.sigma_min]]).astype(np.float32) + sigmas = np.concatenate([sigmas, [self.config.sigma_min]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) if str(device).startswith("mps"): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 40ab394e3e..428eaea6a6 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -233,7 +233,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): """ # Clipping the minimum of all lambda(t) for numerical stability. # This is critical for cosine (squaredcos_cap_v2) noise schedule. - clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped).item() + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped).item() self.noisiest_timestep = self.config.num_train_timesteps - 1 - clipped_idx # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index d9f1f6f4bd..96962c315e 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -325,7 +325,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): log_sigmas = np.log(sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - if self.use_karras_sigmas: + if self.config.use_karras_sigmas: sigmas = self._convert_to_karras(in_sigmas=sigmas) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 78511e4c54..afe2d1456e 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -343,7 +343,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): " 'linear' or 'log_linear'" ) - if self.use_karras_sigmas: + if self.config.use_karras_sigmas: sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index ee101cf2eb..9f759683b5 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -288,7 +288,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): log_sigmas = np.log(sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - if self.use_karras_sigmas: + if self.config.use_karras_sigmas: sigmas = self._convert_to_karras(in_sigmas=sigmas) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index db07b126e4..a19e8f8c65 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -782,7 +782,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test # update inputs_dict for ip-adapter batch_size = inputs_dict["encoder_hidden_states"].shape[0] # for ip-adapter image_embeds has shape [batch_size, num_image, embed_dim] - image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device) + image_embeds = floats_tensor((batch_size, 1, model.config.cross_attention_dim)).to(torch_device) inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds]} # make ip_adapter_1 and ip_adapter_2 @@ -854,7 +854,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test # update inputs_dict for ip-adapter batch_size = inputs_dict["encoder_hidden_states"].shape[0] # for ip-adapter-plus image_embeds has shape [batch_size, num_image, sequence_length, embed_dim] - image_embeds = floats_tensor((batch_size, 1, 1, model.cross_attention_dim)).to(torch_device) + image_embeds = floats_tensor((batch_size, 1, 1, model.config.cross_attention_dim)).to(torch_device) inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds]} # make ip_adapter_1 and ip_adapter_2 diff --git a/tests/others/test_config.py b/tests/others/test_config.py index 3492ec3508..0795796275 100644 --- a/tests/others/test_config.py +++ b/tests/others/test_config.py @@ -272,17 +272,17 @@ class ConfigTester(unittest.TestCase): # now loading it with SampleObject2 should put f into `_use_default_values` config = SampleObject2.from_config(tmpdirname) - assert "f" in config._use_default_values - assert config.f == [1, 3] + assert "f" in config.config._use_default_values + assert config.config.f == [1, 3] # now loading the config, should **NOT** use [1, 3] for `f`, but the default [1, 4] value - # **BECAUSE** it is part of `config._use_default_values` + # **BECAUSE** it is part of `config.config._use_default_values` new_config = SampleObject4.from_config(config.config) - assert new_config.f == [5, 4] + assert new_config.config.f == [5, 4] config.config._use_default_values.pop() new_config_2 = SampleObject4.from_config(config.config) - assert new_config_2.f == [1, 3] + assert new_config_2.config.f == [1, 3] # Nevertheless "e" should still be correctly loaded to [1, 3] from SampleObject2 instead of defaulting to [1, 5] - assert new_config_2.e == [1, 3] + assert new_config_2.config.e == [1, 3] diff --git a/tests/pipelines/dance_diffusion/test_dance_diffusion.py b/tests/pipelines/dance_diffusion/test_dance_diffusion.py index 212505c9ed..e40d813016 100644 --- a/tests/pipelines/dance_diffusion/test_dance_diffusion.py +++ b/tests/pipelines/dance_diffusion/test_dance_diffusion.py @@ -137,7 +137,7 @@ class PipelineIntegrationTests(unittest.TestCase): audio_slice = audio[0, -3:, -3:] - assert audio.shape == (1, 2, pipe.unet.sample_size) + assert audio.shape == (1, 2, pipe.unet.config.sample_size) expected_slice = np.array([-0.0192, -0.0231, -0.0318, -0.0059, 0.0002, -0.0020]) assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2 @@ -155,7 +155,7 @@ class PipelineIntegrationTests(unittest.TestCase): audio_slice = audio[0, -3:, -3:] - assert audio.shape == (1, 2, pipe.unet.sample_size) + assert audio.shape == (1, 2, pipe.unet.config.sample_size) expected_slice = np.array([-0.0367, -0.0488, -0.0771, -0.0525, -0.0444, -0.0341]) assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2