From b756ec6e80b3d94c3ae7dc356bdbbdb426a05dca Mon Sep 17 00:00:00 2001 From: djm <92705171+Foundsheep@users.noreply.github.com> Date: Fri, 20 Dec 2024 07:24:18 +0900 Subject: [PATCH] unet's `sample_size` attribute is to accept tuple(h, w) in `StableDiffusionPipeline` (#10181) --- .../models/unets/unet_2d_condition.py | 2 +- .../pipeline_stable_diffusion.py | 21 ++++++++++++++++--- .../stable_diffusion/test_stable_diffusion.py | 8 +++++++ 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 4f55df32b7..e488f5897e 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -170,7 +170,7 @@ class UNet2DConditionModel( @register_to_config def __init__( self, - sample_size: Optional[int] = None, + sample_size: Optional[Union[int, Tuple[int, int]]] = None, in_channels: int = 4, out_channels: int = 4, center_input_sample: bool = False, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 4fd6a43a95..ac6c8253e4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -255,7 +255,12 @@ class StableDiffusionPipeline( is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + self._is_unet_config_sample_size_int = isinstance(unet.config.sample_size, int) + is_unet_sample_size_less_64 = ( + hasattr(unet.config, "sample_size") + and self._is_unet_config_sample_size_int + and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -902,8 +907,18 @@ class StableDiffusionPipeline( callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor + if not height or not width: + height = ( + self.unet.config.sample_size + if self._is_unet_config_sample_size_int + else self.unet.config.sample_size[0] + ) + width = ( + self.unet.config.sample_size + if self._is_unet_config_sample_size_int + else self.unet.config.sample_size[1] + ) + height, width = height * self.vae_scale_factor, width * self.vae_scale_factor # to deal with lora scaling and other possible forward hooks # 1. Check inputs. Raise error if not correct diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index f37d598c83..ccd5567106 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -840,6 +840,14 @@ class StableDiffusionPipelineFastTests( # they should be the same assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4) + def test_pipeline_accept_tuple_type_unet_sample_size(self): + # the purpose of this test is to see whether the pipeline would accept a unet with the tuple-typed sample size + sd_repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5" + sample_size = [60, 80] + customised_unet = UNet2DConditionModel(sample_size=sample_size) + pipe = StableDiffusionPipeline.from_pretrained(sd_repo_id, unet=customised_unet) + assert pipe.unet.config.sample_size == sample_size + @slow @require_torch_gpu