1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

unet's sample_size attribute is to accept tuple(h, w) in StableDiffusionPipeline (#10181)

This commit is contained in:
djm
2024-12-20 07:24:18 +09:00
committed by GitHub
parent d8825e7697
commit b756ec6e80
3 changed files with 27 additions and 4 deletions

View File

@@ -170,7 +170,7 @@ class UNet2DConditionModel(
@register_to_config
def __init__(
self,
sample_size: Optional[int] = None,
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
in_channels: int = 4,
out_channels: int = 4,
center_input_sample: bool = False,

View File

@@ -255,7 +255,12 @@ class StableDiffusionPipeline(
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
self._is_unet_config_sample_size_int = isinstance(unet.config.sample_size, int)
is_unet_sample_size_less_64 = (
hasattr(unet.config, "sample_size")
and self._is_unet_config_sample_size_int
and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -902,8 +907,18 @@ class StableDiffusionPipeline(
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if not height or not width:
height = (
self.unet.config.sample_size
if self._is_unet_config_sample_size_int
else self.unet.config.sample_size[0]
)
width = (
self.unet.config.sample_size
if self._is_unet_config_sample_size_int
else self.unet.config.sample_size[1]
)
height, width = height * self.vae_scale_factor, width * self.vae_scale_factor
# to deal with lora scaling and other possible forward hooks
# 1. Check inputs. Raise error if not correct

View File

@@ -840,6 +840,14 @@ class StableDiffusionPipelineFastTests(
# they should be the same
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
def test_pipeline_accept_tuple_type_unet_sample_size(self):
    """Check that the pipeline can be built around a UNet whose `sample_size` is a (height, width) pair."""
    # NOTE: the pair is written as a list here rather than a tuple; the
    # equality check below compares against this same list object.
    expected_size = [60, 80]
    unet = UNet2DConditionModel(sample_size=expected_size)

    repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
    pipeline = StableDiffusionPipeline.from_pretrained(repo_id, unet=unet)

    # The custom UNet's configured size must survive pipeline construction unchanged.
    assert pipeline.unet.config.sample_size == expected_size
@slow
@require_torch_gpu