diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
index c56f4af1b8..b367fc7c62 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
@@ -117,84 +117,6 @@ def retrieve_latents(
     raise AttributeError("Could not access latents of provided encoder_output")
 
 
-def prepare_latents_img2img(
-    vae, scheduler, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
-):
-    if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
-        raise ValueError(f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}")
-
-    image = image.to(device=device, dtype=dtype)
-
-    batch_size = batch_size * num_images_per_prompt
-
-    if image.shape[1] == 4:
-        init_latents = image
-
-    else:
-        latents_mean = latents_std = None
-        if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
-            latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
-        if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
-            latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
-        # make sure the VAE is in float32 mode, as it overflows in float16
-        if vae.config.force_upcast:
-            image = image.float()
-            vae.to(dtype=torch.float32)
-
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
-        elif isinstance(generator, list):
-            if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
-                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
-            elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
-                raise ValueError(
-                    f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
-                )
-
-            init_latents = [
-                retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size)
-            ]
-            init_latents = torch.cat(init_latents, dim=0)
-        else:
-            init_latents = retrieve_latents(vae.encode(image), generator=generator)
-
-        if vae.config.force_upcast:
-            vae.to(dtype)
-
-        init_latents = init_latents.to(dtype)
-        if latents_mean is not None and latents_std is not None:
-            latents_mean = latents_mean.to(device=device, dtype=dtype)
-            latents_std = latents_std.to(device=device, dtype=dtype)
-            init_latents = (init_latents - latents_mean) * vae.config.scaling_factor / latents_std
-        else:
-            init_latents = vae.config.scaling_factor * init_latents
-
-    if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
-        # expand init_latents for batch_size
-        additional_image_per_prompt = batch_size // init_latents.shape[0]
-        init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
-    elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
-        raise ValueError(
-            f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
-        )
-    else:
-        init_latents = torch.cat([init_latents], dim=0)
-
-    if add_noise:
-        shape = init_latents.shape
-        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        # get latents
-        init_latents = scheduler.add_noise(init_latents, noise, timestep)
-
-    latents = init_latents
-
-    return latents
-
-
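+# NOTE: standalone img2img latent preparation now lives on
+# StableDiffusionXLImg2ImgPrepareLatentsStep.prepare_latents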
 class StableDiffusionXLInputStep(PipelineBlock):
     model_name = "stable-diffusion-xl"
 
@@ -419,8 +341,6 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
             InputParam("denoising_end"),
             InputParam("strength", default=0.3),
             InputParam("denoising_start"),
-            # YiYi TODO: do we need num_images_per_prompt here?
-            InputParam("num_images_per_prompt", default=1),
         ]
 
     @property
@@ -495,7 +415,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
 
-        block_state.device = components._execution_device
+        device = components._execution_device
 
         block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
             components.scheduler,
@@ -512,14 +432,12 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
             components,
             block_state.num_inference_steps,
             block_state.strength,
-            block_state.device,
+            device,
             denoising_start=block_state.denoising_start if denoising_value_valid(block_state.denoising_start) else None,
         )
 
-        block_state.latent_timestep = block_state.timesteps[:1].repeat(
-            block_state.batch_size * block_state.num_images_per_prompt
-        )
+        block_state.latent_timestep = block_state.timesteps[:1]
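+        # keep a single (shape (1,)) timestep here; the prepare-latents blocks
+        # broadcast it to batch_size * num_images_per_prompt themselves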
 
         if (
             block_state.denoising_end is not None
@@ -527,14 +445,14 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
             and block_state.denoising_end > 0
             and block_state.denoising_end < 1
         ):
-            block_state.discrete_timestep_cutoff = int(
+            discrete_timestep_cutoff = int(
                 round(
                     components.scheduler.config.num_train_timesteps
                     - (block_state.denoising_end * components.scheduler.config.num_train_timesteps)
                 )
             )
             block_state.num_inference_steps = len(
-                list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))
+                list(filter(lambda ts: ts >= discrete_timestep_cutoff, block_state.timesteps))
             )
             block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps]
 
@@ -596,14 +514,14 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
             and block_state.denoising_end > 0
             and block_state.denoising_end < 1
         ):
-            block_state.discrete_timestep_cutoff = int(
+            discrete_timestep_cutoff = int(
                 round(
                     components.scheduler.config.num_train_timesteps
                     - (block_state.denoising_end * components.scheduler.config.num_train_timesteps)
                 )
             )
             block_state.num_inference_steps = len(
-                list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))
+                list(filter(lambda ts: ts >= discrete_timestep_cutoff, block_state.timesteps))
             )
             block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps]
 
@@ -627,7 +545,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
     @property
     def inputs(self) -> List[Tuple[str, Any]]:
         return [
-            InputParam("latents"),
             InputParam("num_images_per_prompt", default=1),
             InputParam("denoising_start"),
             InputParam(
@@ -654,7 +571,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
             ),
             InputParam(
                 "latent_timestep",
-                required=True,
                 type_hint=torch.Tensor,
                 description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.",
             ),
@@ -691,209 +607,99 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
             ),
         ]
 
-    # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self->components
-    # YiYi TODO: update the _encode_vae_image so that we can use #Coped from
-    @staticmethod
-    def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator):
-        latents_mean = latents_std = None
-        if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
-            latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
-        if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
-            latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
-
-        dtype = image.dtype
-        if components.vae.config.force_upcast:
-            image = image.float()
-            components.vae.to(dtype=torch.float32)
-
-        if isinstance(generator, list):
-            image_latents = [
-                retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
-                for i in range(image.shape[0])
-            ]
-            image_latents = torch.cat(image_latents, dim=0)
-        else:
-            image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
-
-        if components.vae.config.force_upcast:
-            components.vae.to(dtype)
-
-        image_latents = image_latents.to(dtype)
-        if latents_mean is not None and latents_std is not None:
-            latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
-            latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
-            image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
-        else:
-            image_latents = components.vae.config.scaling_factor * image_latents
-
-        return image_latents
-
     # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument
-    def prepare_latents_inpaint(
+    def prepare_latents(
         self,
         components,
-        batch_size,
-        num_channels_latents,
-        height,
-        width,
+        image_latents,
         dtype,
         device,
         generator,
-        latents=None,
-        image=None,
         timestep=None,
         is_strength_max=True,
         add_noise=True,
-        return_noise=False,
-        return_image_latents=False,
     ):
-        shape = (
-            batch_size,
-            num_channels_latents,
-            int(height) // components.vae_scale_factor,
-            int(width) // components.vae_scale_factor,
-        )
+        batch_size = image_latents.shape[0]
+
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
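+        # how the branches below initialise the latents:
+        #   add_noise and strength == 1.0 -> pure noise, scaled by init_noise_sigma
+        #   add_noise and strength < 1.0  -> image latents noised down to `timestep`
+        #   not add_noise (denoising_start set) -> image latents used as-is
+        # `noise` is generated in every branch so callers can reuse it later, e.g.
+        # to re-noise the unmasked region during inpainting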
-        if (image is None or timestep is None) and not is_strength_max:
-            raise ValueError(
-                "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
-                "However, either the image or the noise timestep has not been provided."
-            )
-
-        if image.shape[1] == 4:
-            image_latents = image.to(device=device, dtype=dtype)
-            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
-        elif return_image_latents or (latents is None and not is_strength_max):
-            image = image.to(device=device, dtype=dtype)
-            image_latents = self._encode_vae_image(components, image=image, generator=generator)
-            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
-
-        if latents is None and add_noise:
-            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        if add_noise:
+            noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype)
             # if strength is 1. then initialise the latents to noise, else initial to image + noise
             latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep)
             # if pure noise then scale the initial latents by the Scheduler's init sigma
             latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents
-        elif add_noise:
-            noise = latents.to(device)
-            latents = noise * components.scheduler.init_noise_sigma
         else:
-            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype)
             latents = image_latents.to(device)
 
-        outputs = (latents,)
+        return latents, noise
 
-        if return_noise:
-            outputs += (noise,)
-
-        if return_image_latents:
-            outputs += (image_latents,)
-
-        return outputs
-
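+    # the VAE-encoder step hands over image/mask tensors with batch size 1 and a
+    # single latent timestep; all batch expansion happens inside this block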
- " Make sure the number of images that you pass is divisible by the total requested batch size." - ) - masked_image_latents = masked_image_latents.repeat( - batch_size // masked_image_latents.shape[0], 1, 1, 1 - ) - - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - - return mask, masked_image_latents + + def check_inputs(self, image_latents, mask, masked_image_latents): + if image_latents.shape[0] != 1: + raise ValueError(f"image_latents should have have batch size 1, but got {image_latents.shape[0]}") + if mask.shape[0] != 1: + raise ValueError(f"mask should have have batch size 1, but got {mask.shape[0]}") + if masked_image_latents is not None and masked_image_latents.shape[0] != 1: + raise ValueError(f"masked_image_latents should have have batch size 1, but got {masked_image_latents.shape[0]}") + + if latent_timestep is not None and len(latent_timestep.shape) > 0: + raise ValueError(f"latent_timestep should be a scalar, but got {latent_timestep.shape}") + + @torch.no_grad() def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) + + self.check_inputs(block_state.image_latents, block_state.mask, block_state.masked_image_latents, block_state.latent_timestep) - block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - block_state.device = components._execution_device + dtype = block_state.dtype if block_state.dtype is not None else block_state.image_latents.dtype + device = components._execution_device - block_state.is_strength_max = block_state.strength == 1.0 - - # for non-inpainting specific unet, we do not need masked_image_latents - if hasattr(components, "unet") and components.unet is not None: - if components.unet.config.in_channels == 4: - block_state.masked_image_latents = None - - block_state.add_noise = True if block_state.denoising_start is None else False - - block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor - block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor - - block_state.latents, block_state.noise = self.prepare_latents_inpaint( - components, - block_state.batch_size * block_state.num_images_per_prompt, - components.num_channels_latents, - block_state.height, - block_state.width, - block_state.dtype, - block_state.device, - block_state.generator, - block_state.latents, - image=block_state.image_latents, - timestep=block_state.latent_timestep, - is_strength_max=block_state.is_strength_max, - add_noise=block_state.add_noise, - return_noise=True, - return_image_latents=False, - ) + final_batch_size = block_state.batch_size * block_state.num_images_per_prompt + _, _, height_latents, width_latents = block_state.image_latents.shape + + block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype) + block_state.image_latents = block_state.image_latents.repeat(final_batch_size, 1, 1, 1) # 7. 
 
         # 7. Prepare mask latent variables
-        block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
-            components,
-            block_state.mask,
-            block_state.masked_image_latents,
-            block_state.batch_size * block_state.num_images_per_prompt,
-            block_state.height,
-            block_state.width,
-            block_state.dtype,
-            block_state.device,
-            block_state.generator,
+        block_state.mask = torch.nn.functional.interpolate(
+            block_state.mask, size=(height_latents, width_latents)
         )
+        block_state.mask = block_state.mask.to(device=device, dtype=dtype)
+        block_state.mask = block_state.mask.repeat(final_batch_size, 1, 1, 1)
+
+        if block_state.masked_image_latents is not None:
+            block_state.masked_image_latents = block_state.masked_image_latents.to(device=device, dtype=dtype)
+            block_state.masked_image_latents = block_state.masked_image_latents.repeat(final_batch_size, 1, 1, 1)
+
+        if block_state.latent_timestep is not None:
+            block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size)
+            block_state.latent_timestep = block_state.latent_timestep.to(device=device)
+
+        is_strength_max = block_state.strength == 1.0
+        add_noise = True if block_state.denoising_start is None else False
+
+        block_state.latents, block_state.noise = self.prepare_latents(
+            components=components,
+            image_latents=block_state.image_latents,
+            dtype=dtype,
+            device=device,
+            generator=block_state.generator,
+            timestep=block_state.latent_timestep,
+            is_strength_max=is_strength_max,
+            add_noise=add_noise,
+        )
+
         self.set_block_state(state, block_state)
 
@@ -906,7 +712,6 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
     @property
     def expected_components(self) -> List[ComponentSpec]:
         return [
-            ComponentSpec("vae", AutoencoderKL),
             ComponentSpec("scheduler", EulerDiscreteScheduler),
         ]
 
@@ -917,7 +722,6 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
     @property
     def inputs(self) -> List[Tuple[str, Any]]:
         return [
-            InputParam("latents"),
             InputParam("num_images_per_prompt", default=1),
             InputParam("denoising_start"),
         ]
@@ -928,7 +732,6 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
             InputParam("generator"),
             InputParam(
                 "latent_timestep",
-                required=True,
                 type_hint=torch.Tensor,
                 description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.",
             ),
@@ -944,7 +747,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
                 type_hint=int,
                 description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
             ),
-            InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"),
+            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
         ]
 
     @property
@@ -954,27 +757,52 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
                 "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
             )
         ]
+
+    def check_inputs(self, image_latents):
+        if image_latents.shape[0] != 1:
+            raise ValueError(f"image_latents should have batch size 1, but got {image_latents.shape[0]}")
+
+    @staticmethod
+    def prepare_latents(image_latents, scheduler, timestep, dtype, device, generator=None):
+        if isinstance(generator, list) and len(generator) != image_latents.shape[0]:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {image_latents.shape[0]}. Make sure the batch size matches the length of the generators."
+            )
+
+        noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype)
+        latents = scheduler.add_noise(image_latents, noise, timestep)
+
+        return latents
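+    # prepare_latents noises the already batch-expanded image latents down to
+    # `timestep`, the first entry of the strength-truncated schedule; e.g. 50
+    # inference steps at strength=0.3 keep 15 steps and noise to timesteps[0]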
 
     @torch.no_grad()
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
 
+        self.check_inputs(block_state.image_latents)
+
-        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
-        block_state.device = components._execution_device
 
-        block_state.add_noise = True if block_state.denoising_start is None else False
-        if block_state.latents is None:
-            block_state.latents = prepare_latents_img2img(
-                components.vae,
-                components.scheduler,
+        dtype = block_state.dtype if block_state.dtype is not None else block_state.image_latents.dtype
+        device = components._execution_device
+
+        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
+
+        block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype)
+        block_state.image_latents = block_state.image_latents.repeat(final_batch_size, 1, 1, 1)
+
+        if block_state.latent_timestep is not None:
+            block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size)
+            block_state.latent_timestep = block_state.latent_timestep.to(device=device)
+
+        add_noise = True if block_state.denoising_start is None else False
+
+        if add_noise:
+            block_state.latents = self.prepare_latents(
                 block_state.image_latents,
-                block_state.latent_timestep,
-                block_state.batch_size,
-                block_state.num_images_per_prompt,
-                block_state.dtype,
-                block_state.device,
-                block_state.generator,
-                block_state.add_noise,
+                components.scheduler,
+                timestep=block_state.latent_timestep,
+                dtype=dtype,
+                device=device,
+                generator=block_state.generator,
             )
+        else:
+            block_state.latents = block_state.image_latents
 
         self.set_block_state(state, block_state)
 
@@ -988,7 +816,6 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
     def expected_components(self) -> List[ComponentSpec]:
         return [
             ComponentSpec("scheduler", EulerDiscreteScheduler),
-            ComponentSpec("vae", AutoencoderKL),
         ]
 
     @property
@@ -1026,15 +853,15 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
         ]
 
     @staticmethod
-    def check_inputs(components, block_state):
+    def check_inputs(components, height, width):
         if (
-            block_state.height is not None
-            and block_state.height % components.vae_scale_factor != 0
-            or block_state.width is not None
-            and block_state.width % components.vae_scale_factor != 0
+            height is not None
+            and height % components.vae_scale_factor != 0
+            or width is not None
+            and width % components.vae_scale_factor != 0
         ):
             raise ValueError(
-                f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}."
+                f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {height} and {width}."
            )
 
     @staticmethod
@@ -1065,26 +892,27 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
 
-        if block_state.dtype is None:
-            block_state.dtype = components.vae.dtype
+        dtype = block_state.dtype
+        if dtype is None:
+            dtype = components.unet.dtype if getattr(components, "unet", None) is not None else torch.float32
 
-        block_state.device = components._execution_device
+        device = components._execution_device
 
-        self.check_inputs(components, block_state)
+        self.check_inputs(components, block_state.height, block_state.width)
 
+        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
+        height = block_state.height or components.default_sample_size * components.vae_scale_factor
+        width = block_state.width or components.default_sample_size * components.vae_scale_factor
 
-        block_state.height = block_state.height or components.default_sample_size * components.vae_scale_factor
-        block_state.width = block_state.width or components.default_sample_size * components.vae_scale_factor
-        block_state.num_channels_latents = components.num_channels_latents
         block_state.latents = self.prepare_latents(
-            components,
-            block_state.batch_size * block_state.num_images_per_prompt,
-            block_state.num_channels_latents,
-            block_state.height,
-            block_state.width,
-            block_state.dtype,
-            block_state.device,
-            block_state.generator,
-            block_state.latents,
+            components=components,
+            batch_size=final_batch_size,
+            num_channels_latents=components.num_channels_latents,
+            height=height,
+            width=width,
+            dtype=dtype,
+            device=device,
+            generator=block_state.generator,
+            latents=block_state.latents,
         )
 
         self.set_block_state(state, block_state)
 
@@ -1103,15 +931,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
 
     @property
     def expected_components(self) -> List[ComponentSpec]:
-        return [
-            ComponentSpec("unet", UNet2DConditionModel),
-            ComponentSpec(
-                "guider",
-                ClassifierFreeGuidance,
-                config=FrozenDict({"guidance_scale": 7.5}),
-                default_creation_method="from_config",
-            ),
-        ]
+        return [ComponentSpec("unet", UNet2DConditionModel)]
 
     @property
     def description(self) -> str:
@@ -1129,6 +949,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
             InputParam("num_images_per_prompt", default=1),
             InputParam("aesthetic_score", default=6.0),
             InputParam("negative_aesthetic_score", default=2.0),
+            InputParam("embedded_guidance_scale", default=7.5),
         ]
 
     @property
@@ -1259,18 +1080,20 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
     @torch.no_grad()
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
-        block_state.device = components._execution_device
+        device = components._execution_device
+        dtype = block_state.pooled_prompt_embeds.dtype
 
-        block_state.vae_scale_factor = components.vae_scale_factor
+        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
+        text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1])
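+        # SDXL micro-conditioning: original/target sizes, crop coordinates and the
+        # aesthetic scores below are packed into `add_time_ids` for the refiner unet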
 
-        block_state.height, block_state.width = block_state.latents.shape[-2:]
-        block_state.height = block_state.height * block_state.vae_scale_factor
-        block_state.width = block_state.width * block_state.vae_scale_factor
+        # original_size/negative_original_size/target_size/negative_target_size all
+        # default to None; fall back to the output image size implied by the latents
+        _, _, height_latents, width_latents = block_state.latents.shape
+        height = height_latents * components.vae_scale_factor
+        width = width_latents * components.vae_scale_factor
 
-        block_state.original_size = block_state.original_size or (block_state.height, block_state.width)
-        block_state.target_size = block_state.target_size or (block_state.height, block_state.width)
-
-        block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1])
+        block_state.original_size = block_state.original_size or (height, width)
+        block_state.target_size = block_state.target_size or (height, width)
 
         if block_state.negative_original_size is None:
             block_state.negative_original_size = block_state.original_size
@@ -1287,15 +1110,11 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
             block_state.negative_original_size,
             block_state.negative_crops_coords_top_left,
             block_state.negative_target_size,
-            dtype=block_state.pooled_prompt_embeds.dtype,
-            text_encoder_projection_dim=block_state.text_encoder_projection_dim,
+            dtype=dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
         )
-        block_state.add_time_ids = block_state.add_time_ids.repeat(
-            block_state.batch_size * block_state.num_images_per_prompt, 1
-        ).to(device=block_state.device)
-        block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(
-            block_state.batch_size * block_state.num_images_per_prompt, 1
-        ).to(device=block_state.device)
+        block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device)
+        block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(device=device)
 
         # Optionally get Guidance Scale Embedding for LCM
         block_state.timestep_cond = None
@@ -1305,12 +1124,10 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
             and components.unet.config.time_cond_proj_dim is not None
         ):
-            # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this!
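+            # LCM-style unets expose `time_cond_proj_dim` and condition on the
+            # guidance scale directly; w = embedded_guidance_scale is embedded as w - 1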
-            block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(
-                block_state.batch_size * block_state.num_images_per_prompt
-            )
+            block_state.guidance_scale_tensor = torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size).to(device=device)
             block_state.timestep_cond = self.get_guidance_scale_embedding(
                 block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim
-            ).to(device=block_state.device, dtype=block_state.latents.dtype)
+            ).to(device=device, dtype=dtype)
 
         self.set_block_state(state, block_state)
         return components, state
@@ -1325,15 +1142,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
 
     @property
     def expected_components(self) -> List[ComponentSpec]:
-        return [
-            ComponentSpec("unet", UNet2DConditionModel),
-            ComponentSpec(
-                "guider",
-                ClassifierFreeGuidance,
-                config=FrozenDict({"guidance_scale": 7.5}),
-                default_creation_method="from_config",
-            ),
-        ]
+        return [ComponentSpec("unet", UNet2DConditionModel)]
 
     @property
     def inputs(self) -> List[Tuple[str, Any]]:
@@ -1345,6 +1154,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
             InputParam("crops_coords_top_left", default=(0, 0)),
             InputParam("negative_crops_coords_top_left", default=(0, 0)),
             InputParam("num_images_per_prompt", default=1),
+            InputParam("embedded_guidance_scale", default=7.5),
         ]
 
     @property
@@ -1442,24 +1252,26 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
     @torch.no_grad()
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
-        block_state.device = components._execution_device
+        device = components._execution_device
+        dtype = block_state.pooled_prompt_embeds.dtype
+        text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1])
 
-        block_state.height, block_state.width = block_state.latents.shape[-2:]
-        block_state.height = block_state.height * components.vae_scale_factor
-        block_state.width = block_state.width * components.vae_scale_factor
+        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
+        _, _, height_latents, width_latents = block_state.latents.shape
+        height = height_latents * components.vae_scale_factor
+        width = width_latents * components.vae_scale_factor
 
-        block_state.original_size = block_state.original_size or (block_state.height, block_state.width)
-        block_state.target_size = block_state.target_size or (block_state.height, block_state.width)
+        block_state.original_size = block_state.original_size or (height, width)
+        block_state.target_size = block_state.target_size or (height, width)
 
-        block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1])
         block_state.add_time_ids = self._get_add_time_ids(
             components,
             block_state.original_size,
             block_state.crops_coords_top_left,
             block_state.target_size,
-            block_state.pooled_prompt_embeds.dtype,
-            text_encoder_projection_dim=block_state.text_encoder_projection_dim,
+            dtype=dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
        )
         if block_state.negative_original_size is not None and block_state.negative_target_size is not None:
             block_state.negative_add_time_ids = self._get_add_time_ids(
@@ -1467,18 +1279,14 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
                 block_state.negative_original_size,
                 block_state.negative_crops_coords_top_left,
                 block_state.negative_target_size,
-                block_state.pooled_prompt_embeds.dtype,
-                text_encoder_projection_dim=block_state.text_encoder_projection_dim,
+                dtype=dtype,
+                text_encoder_projection_dim=text_encoder_projection_dim,
             )
         else:
             block_state.negative_add_time_ids = block_state.add_time_ids
 
-        block_state.add_time_ids = block_state.add_time_ids.repeat(
-            block_state.batch_size * block_state.num_images_per_prompt, 1
-        ).to(device=block_state.device)
-        block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(
-            block_state.batch_size * block_state.num_images_per_prompt, 1
-        ).to(device=block_state.device)
+        block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device)
+        block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(device=device)
 
         # Optionally get Guidance Scale Embedding for LCM
         block_state.timestep_cond = None
@@ -1488,9 +1296,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
             and components.unet.config.time_cond_proj_dim is not None
         ):
-            # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this!
-            block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(
-                block_state.batch_size * block_state.num_images_per_prompt
-            )
+            block_state.guidance_scale_tensor = torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size).to(device=device)
             block_state.timestep_cond = self.get_guidance_scale_embedding(
                 block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim
-            ).to(device=block_state.device, dtype=block_state.latents.dtype)
+            ).to(device=device, dtype=dtype)
 
@@ -1613,14 +1419,18 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
 
-        # (1) prepare controlnet inputs
-        block_state.device = components._execution_device
-        block_state.height, block_state.width = block_state.latents.shape[-2:]
-        block_state.height = block_state.height * components.vae_scale_factor
-        block_state.width = block_state.width * components.vae_scale_factor
-
         controlnet = unwrap_module(components.controlnet)
 
+        device = components._execution_device
+        dtype = controlnet.dtype
+
+        _, _, height_latents, width_latents = block_state.latents.shape
+        height = height_latents * components.vae_scale_factor
+        width = width_latents * components.vae_scale_factor
+        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
+
+        # (1) prepare controlnet inputs
+        # (1.1)
         # control_guidance_start/control_guidance_end (align format)
         if not isinstance(block_state.control_guidance_start, list) and isinstance(
@@ -1670,12 +1480,12 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
             block_state.control_image = self.prepare_control_image(
                 components,
                 image=block_state.control_image,
-                width=block_state.width,
-                height=block_state.height,
-                batch_size=block_state.batch_size * block_state.num_images_per_prompt,
+                width=width,
+                height=height,
+                batch_size=final_batch_size,
                 num_images_per_prompt=block_state.num_images_per_prompt,
-                device=block_state.device,
-                dtype=controlnet.dtype,
+                device=device,
+                dtype=dtype,
                 crops_coords=block_state.crops_coords,
             )
         elif isinstance(controlnet, MultiControlNetModel):
             block_state.control_images = []
 
             for control_image_ in block_state.control_image:
                 control_image = self.prepare_control_image(
                     components,
                     image=control_image_,
-                    width=block_state.width,
-                    height=block_state.height,
-                    batch_size=block_state.batch_size * block_state.num_images_per_prompt,
+                    width=width,
+                    height=height,
+                    batch_size=final_batch_size,
                     num_images_per_prompt=block_state.num_images_per_prompt,
-                    device=block_state.device,
-                    dtype=controlnet.dtype,
+                    device=device,
+                    dtype=dtype,
                     crops_coords=block_state.crops_coords,
                 )
 
@@ -1852,9 +1662,10 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
         device = components._execution_device
         dtype = block_state.dtype or components.controlnet.dtype
 
-        block_state.height, block_state.width = block_state.latents.shape[-2:]
-        block_state.height = block_state.height * components.vae_scale_factor
-        block_state.width = block_state.width * components.vae_scale_factor
+        _, _, height_latents, width_latents = block_state.latents.shape
+        height = height_latents * components.vae_scale_factor
+        width = width_latents * components.vae_scale_factor
+        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
 
         # control_guidance_start/control_guidance_end (align format)
         if not isinstance(block_state.control_guidance_start, list) and isinstance(
@@ -1900,15 +1711,15 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
             block_state.control_image[idx] = self.prepare_control_image(
                 components,
                 image=block_state.control_image[idx],
-                width=block_state.width,
-                height=block_state.height,
-                batch_size=block_state.batch_size * block_state.num_images_per_prompt,
+                width=width,
+                height=height,
+                batch_size=final_batch_size,
                 num_images_per_prompt=block_state.num_images_per_prompt,
                 device=device,
                 dtype=dtype,
                 crops_coords=block_state.crops_coords,
             )
-            block_state.height, block_state.width = block_state.control_image[idx].shape[-2:]
+            _, _, height, width = block_state.control_image[idx].shape
 
         # controlnet_keep
         block_state.controlnet_keep = []
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
index bd0e962140..99a677dfe6 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
@@ -813,29 +813,9 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
             )
             mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
 
-        if masked_image is not None and masked_image.shape[1] == 4:
-            masked_image_latents = masked_image
-        else:
-            masked_image_latents = None
-
-        if masked_image is not None:
-            if masked_image_latents is None:
-                masked_image = masked_image.to(device=device, dtype=dtype)
-                masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
-
-            if masked_image_latents.shape[0] < batch_size:
-                if not batch_size % masked_image_latents.shape[0] == 0:
-                    raise ValueError(
-                        "The passed images and the required batch size don't match. Images are supposed to be duplicated"
-                        f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
-                        " Make sure the number of images that you pass is divisible by the total requested batch size."
-                    )
-                masked_image_latents = masked_image_latents.repeat(
-                    batch_size // masked_image_latents.shape[0], 1, 1, 1
-                )
-
-            # aligning device to prevent device errors when concating it with the latent model input
-            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+        if masked_image is not None:
+            masked_image = masked_image.to(device=device, dtype=dtype)
+            masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
+        else:
+            masked_image_latents = None
 
         return mask, masked_image_latents