mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
unet's sample_size attribute is to accept tuple(h, w) in StableDiffusionPipeline (#10181)
This commit is contained in:
@@ -170,7 +170,7 @@ class UNet2DConditionModel(
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: Optional[int] = None,
|
||||
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
|
||||
in_channels: int = 4,
|
||||
out_channels: int = 4,
|
||||
center_input_sample: bool = False,
|
||||
|
||||
@@ -255,7 +255,12 @@ class StableDiffusionPipeline(
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
self._is_unet_config_sample_size_int = isinstance(unet.config.sample_size, int)
|
||||
is_unet_sample_size_less_64 = (
|
||||
hasattr(unet.config, "sample_size")
|
||||
and self._is_unet_config_sample_size_int
|
||||
and unet.config.sample_size < 64
|
||||
)
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
@@ -902,8 +907,18 @@ class StableDiffusionPipeline(
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
if not height or not width:
|
||||
height = (
|
||||
self.unet.config.sample_size
|
||||
if self._is_unet_config_sample_size_int
|
||||
else self.unet.config.sample_size[0]
|
||||
)
|
||||
width = (
|
||||
self.unet.config.sample_size
|
||||
if self._is_unet_config_sample_size_int
|
||||
else self.unet.config.sample_size[1]
|
||||
)
|
||||
height, width = height * self.vae_scale_factor, width * self.vae_scale_factor
|
||||
# to deal with lora scaling and other possible forward hooks
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
|
||||
@@ -840,6 +840,14 @@ class StableDiffusionPipelineFastTests(
|
||||
# they should be the same
|
||||
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
|
||||
|
||||
def test_pipeline_accept_tuple_type_unet_sample_size(self):
    """Check that the pipeline accepts a UNet whose `sample_size` is a (height, width) tuple.

    A UNet is built with a non-square, tuple-typed `sample_size`, handed to
    `StableDiffusionPipeline.from_pretrained` via the `unet` argument, and the
    value is then read back from the pipeline's UNet config.
    """
    # NOTE(review): loading by hub repo id presumably fetches the remaining
    # pipeline components over the network -- confirm that is acceptable here.
    sd_repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"

    # Use an actual tuple (the annotated type) with height != width, so a
    # swapped or collapsed dimension would be caught by the assertion below.
    sample_size = (60, 80)
    customised_unet = UNet2DConditionModel(sample_size=sample_size)
    pipe = StableDiffusionPipeline.from_pretrained(sd_repo_id, unet=customised_unet)

    # Normalise to a tuple before comparing: a config that stores the value as
    # a list (e.g. after a JSON round-trip) must still count as equal.
    assert tuple(pipe.unet.config.sample_size) == sample_size
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user