diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
index 5ae9e63851..0d068f90f7 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
@@ -140,174 +140,6 @@ def retrieve_latents(
-# YiYi Notes: I think we do not need this, we can add loader methods on the components class
-class StableDiffusionXLLoraStep(PipelineBlock):
-
-    model_name = "stable-diffusion-xl"
-
-    @property
-    def description(self) -> str:
-        return (
-            "Lora step that handles all the lora related tasks: load/unload lora weights into unet and text encoders, manage lora adapters etc"
-            " See [StableDiffusionXLLoraLoaderMixin](https://huggingface.co/docs/diffusers/api/loaders/lora#diffusers.loaders.StableDiffusionXLLoraLoaderMixin)"
-            " for more details"
-        )
-
-    @property
-    def expected_components(self) -> List[ComponentSpec]:
-        return [
-            ComponentSpec("text_encoder", CLIPTextModel),
-            ComponentSpec("text_encoder_2", CLIPTextModelWithProjection),
-            ComponentSpec("unet", UNet2DConditionModel),
-        ]
-
-
-    @torch.no_grad()
-    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
-        raise EnvironmentError("StableDiffusionXLLoraStep is designed to be used to load lora weights, __call__ is not implemented")
-
-
-class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):
-    model_name = "stable-diffusion-xl"
-
-
-    @property
-    def description(self) -> str:
-        return (
-            "IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc"
-            " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)"
-            " for more details"
-        )
-
-    @property
-    def expected_components(self) -> List[ComponentSpec]:
-        return [
-            ComponentSpec("image_encoder", CLIPVisionModelWithProjection),
-            ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"),
-            ComponentSpec("unet", UNet2DConditionModel),
-            ComponentSpec(
-                "guider",
-                ClassifierFreeGuidance,
-                config=FrozenDict({"guidance_scale": 7.5}),
-                default_creation_method="from_config"),
-        ]
-
-    @property
-    def inputs(self) -> List[InputParam]:
-        return [
-            InputParam(
-                "ip_adapter_image",
-                PipelineImageInput,
-                required=True,
-                description="The image(s) to be used as ip adapter"
-            )
-        ]
-
-
-    @property
-    def intermediates_outputs(self) -> List[OutputParam]:
-        return [
-            OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"),
-            OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings")
-        ]
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components
-    def encode_image(self, components, image, device, num_images_per_prompt, output_hidden_states=None):
-        dtype = next(components.image_encoder.parameters()).dtype
-
-        if not isinstance(image, torch.Tensor):
-            image = components.feature_extractor(image, return_tensors="pt").pixel_values
-
-        image = image.to(device=device, dtype=dtype)
-        if output_hidden_states:
-            image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2]
-            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
-            uncond_image_enc_hidden_states = components.image_encoder(
-                torch.zeros_like(image), output_hidden_states=True
-            ).hidden_states[-2]
-            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
-                num_images_per_prompt, dim=0
-            )
-            return image_enc_hidden_states, uncond_image_enc_hidden_states
-        else:
-            image_embeds = components.image_encoder(image).image_embeds
-            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-            uncond_image_embeds = torch.zeros_like(image_embeds)
-
-            return image_embeds, uncond_image_embeds
-
-    # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
-    def prepare_ip_adapter_image_embeds(
-        self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds
-    ):
-        image_embeds = []
-        if prepare_unconditional_embeds:
-            negative_image_embeds = []
-        if ip_adapter_image_embeds is None:
-            if not isinstance(ip_adapter_image, list):
-                ip_adapter_image = [ip_adapter_image]
-
-            if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
-                raise ValueError(
-                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
-                )
-
-            for single_ip_adapter_image, image_proj_layer in zip(
-                ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers
-            ):
-                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
-                single_image_embeds, single_negative_image_embeds = self.encode_image(
-                    components, single_ip_adapter_image, device, 1, output_hidden_state
-                )
-
-                image_embeds.append(single_image_embeds[None, :])
-                if prepare_unconditional_embeds:
-                    negative_image_embeds.append(single_negative_image_embeds[None, :])
-        else:
-            for single_image_embeds in ip_adapter_image_embeds:
-                if prepare_unconditional_embeds:
-                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
-                    negative_image_embeds.append(single_negative_image_embeds)
-                image_embeds.append(single_image_embeds)
-
-        ip_adapter_image_embeds = []
-        for i, single_image_embeds in enumerate(image_embeds):
-            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
-            if prepare_unconditional_embeds:
-                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
-                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
-
-            single_image_embeds = single_image_embeds.to(device=device)
-            ip_adapter_image_embeds.append(single_image_embeds)
-
-        return ip_adapter_image_embeds
-
-    @torch.no_grad()
-    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
-        data = self.get_block_state(state)
-
-        data.prepare_unconditional_embeds = pipeline.guider.num_conditions > 1
-        data.device = pipeline._execution_device
-
-        data.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds(
-            pipeline,
-            ip_adapter_image=data.ip_adapter_image,
-            ip_adapter_image_embeds=None,
-            device=data.device,
-            num_images_per_prompt=1,
-            prepare_unconditional_embeds=data.prepare_unconditional_embeds,
-        )
-        if data.prepare_unconditional_embeds:
-            data.negative_ip_adapter_embeds = []
-            for i, image_embeds in enumerate(data.ip_adapter_embeds):
-                negative_image_embeds, image_embeds = image_embeds.chunk(2)
-                data.negative_ip_adapter_embeds.append(negative_image_embeds)
-                data.ip_adapter_embeds[i] = image_embeds
-
-        self.add_block_state(state, data)
-        return pipeline, state
-
-
 class StableDiffusionXLTextEncoderStep(PipelineBlock):
     model_name = "stable-diffusion-xl"
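
For reviewers tracing the removal: the heart of the deleted IP adapter step is a batch-layout contract. prepare_ip_adapter_image_embeds repeats each per-adapter embedding num_images_per_prompt times and, when the guider requires unconditional inputs, concatenates the negative embeddings in front of the positive ones along the batch dimension; __call__ then splits them back apart with chunk(2). Below is a minimal sketch of that round trip, with hypothetical tensor shapes standing in for real CLIP encoder outputs (illustration only, not the diffusers API):

import torch

num_images_per_prompt = 2
# Hypothetical stand-ins: one embedding per IP adapter, shape (1, seq, dim).
image_embeds = [torch.randn(1, 4, 8)]
negative_image_embeds = [torch.zeros(1, 4, 8)]

prepared = []
for pos, neg in zip(image_embeds, negative_image_embeds):
    # Duplicate each embedding once per generated image.
    pos = torch.cat([pos] * num_images_per_prompt, dim=0)
    neg = torch.cat([neg] * num_images_per_prompt, dim=0)
    # Unconditional embeddings are stacked in front along the batch
    # dimension, mirroring prepare_ip_adapter_image_embeds above.
    prepared.append(torch.cat([neg, pos], dim=0))

# __call__ later undoes the stacking with chunk(2).
negative, positive = prepared[0].chunk(2)
assert negative.shape[0] == positive.shape[0] == num_images_per_prompt

This is why the removed __call__ passes prepare_unconditional_embeds=pipeline.guider.num_conditions > 1: only a guider that evaluates more than one condition (e.g. classifier-free guidance) needs the negative half of the batch.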