mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-29 07:22:12 +03:00
remove lora step and ip-adapter step -> no longer needed
@@ -140,174 +140,6 @@ def retrieve_latents(
# YiYi Notes: I think we do not need this, we can add loader methods on the components class
class StableDiffusionXLLoraStep(PipelineBlock):

    model_name = "stable-diffusion-xl"

    @property
    def description(self) -> str:
        return (
            "Lora step that handles all the lora-related tasks: load/unload lora weights into the unet and text encoders, manage lora adapters, etc."
            " See [StableDiffusionXLLoraLoaderMixin](https://huggingface.co/docs/diffusers/api/loaders/lora#diffusers.loaders.StableDiffusionXLLoraLoaderMixin)"
            " for more details."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", CLIPTextModel),
            ComponentSpec("text_encoder_2", CLIPTextModelWithProjection),
            ComponentSpec("unet", UNet2DConditionModel),
        ]

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        raise EnvironmentError("StableDiffusionXLLoraStep is designed to be used to load lora weights; __call__ is not implemented")
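
# Illustrative sketch, not part of the original file: since __call__ above only
# raises, this block existed purely to carry loader behavior. With it removed,
# LoRA weights go through the standard StableDiffusionXLLoraLoaderMixin API
# instead. The helper and checkpoint id below are hypothetical placeholders.
def _example_load_lora(pipeline):
    # load lora weights into the unet and both text encoders
    pipeline.load_lora_weights("org/sdxl-lora", adapter_name="style")
    # activate the adapter at a reduced scale
    pipeline.set_adapters(["style"], adapter_weights=[0.8])
    # ...generate images..., then remove the weights again
    pipeline.unload_lora_weights()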

class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):

    model_name = "stable-diffusion-xl"

    @property
    def description(self) -> str:
        return (
            "IP Adapter step that handles all the ip adapter-related tasks: load/unload ip adapter weights into the unet, prepare ip adapter image embeddings, etc."
            " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)"
            " for more details."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("image_encoder", CLIPVisionModelWithProjection),
            ComponentSpec(
                "feature_extractor",
                CLIPImageProcessor,
                config=FrozenDict({"size": 224, "crop_size": 224}),
                default_creation_method="from_config",
            ),
            ComponentSpec("unet", UNet2DConditionModel),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 7.5}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "ip_adapter_image",
                PipelineImageInput,
                required=True,
                description="The image(s) to be used as ip adapter",
            )
        ]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"),
            OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings"),
        ]
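
    # Illustrative note, not part of the original file: the "guider" declared above
    # defaults to classifier-free guidance with guidance_scale=7.5, i.e. roughly
    #
    #     noise_pred = noise_uncond + 7.5 * (noise_cond - noise_uncond)
    #
    # which is why this step prepares *negative* ip adapter embeddings as well.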

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components
    def encode_image(self, components, image, device, num_images_per_prompt, output_hidden_states=None):
        dtype = next(components.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = components.feature_extractor(image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        if output_hidden_states:
            image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2]
            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
            uncond_image_enc_hidden_states = components.image_encoder(
                torch.zeros_like(image), output_hidden_states=True
            ).hidden_states[-2]
            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
                num_images_per_prompt, dim=0
            )
            return image_enc_hidden_states, uncond_image_enc_hidden_states
        else:
            image_embeds = components.image_encoder(image).image_embeds
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            uncond_image_embeds = torch.zeros_like(image_embeds)

            return image_embeds, uncond_image_embeds
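
    # Illustrative note, not part of the original file: output_hidden_states is chosen
    # per projection layer in prepare_ip_adapter_image_embeds below. A plain
    # ImageProjection layer consumes the pooled image_embeds, while "plus"-style
    # projections consume the penultimate hidden states, e.g.
    #
    #     pos, neg = self.encode_image(components, pil_image, device, 1, True)
    #     # pos, neg: hidden_states[-2] of the image encoder, negatives from a zero image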

    # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
        self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds
    ):
        image_embeds = []
        if prepare_unconditional_embeds:
            negative_image_embeds = []
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers
            ):
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    components, single_ip_adapter_image, device, 1, output_hidden_state
                )

                image_embeds.append(single_image_embeds[None, :])
                if prepare_unconditional_embeds:
                    negative_image_embeds.append(single_negative_image_embeds[None, :])
        else:
            for single_image_embeds in ip_adapter_image_embeds:
                if prepare_unconditional_embeds:
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    negative_image_embeds.append(single_negative_image_embeds)
                image_embeds.append(single_image_embeds)

        ip_adapter_image_embeds = []
        for i, single_image_embeds in enumerate(image_embeds):
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            if prepare_unconditional_embeds:
                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

            single_image_embeds = single_image_embeds.to(device=device)
            ip_adapter_image_embeds.append(single_image_embeds)

        return ip_adapter_image_embeds
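
    # Illustrative note, not part of the original file: the returned list holds one
    # tensor per IP Adapter. When prepare_unconditional_embeds is True, each tensor
    # stacks [negative, positive] along dim 0, so callers recover the two halves with
    #
    #     negative_embeds, positive_embeds = embeds.chunk(2)
    #
    # exactly as __call__ does below.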

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        data = self.get_block_state(state)

        data.prepare_unconditional_embeds = pipeline.guider.num_conditions > 1
        data.device = pipeline._execution_device

        data.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds(
            pipeline,
            ip_adapter_image=data.ip_adapter_image,
            ip_adapter_image_embeds=None,
            device=data.device,
            num_images_per_prompt=1,
            prepare_unconditional_embeds=data.prepare_unconditional_embeds,
        )
        if data.prepare_unconditional_embeds:
            data.negative_ip_adapter_embeds = []
            for i, image_embeds in enumerate(data.ip_adapter_embeds):
                negative_image_embeds, image_embeds = image_embeds.chunk(2)
                data.negative_ip_adapter_embeds.append(negative_image_embeds)
                data.ip_adapter_embeds[i] = image_embeds

        self.add_block_state(state, data)
        return pipeline, state
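
# Illustrative sketch, not part of the original file: with the block removed, ip
# adapter weights are expected to be loaded via the mixin API directly. The repo,
# subfolder, and weight names below are the usual public IP-Adapter checkpoints,
# shown only as an example; `_example_use_ip_adapter` is hypothetical.
def _example_use_ip_adapter(pipeline, state):
    pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.safetensors")
    # running the (now removed) block would then populate the state with embeddings:
    # pipeline, state = StableDiffusionXLIPAdapterStep()(pipeline, state)
    return pipeline, state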

class StableDiffusionXLTextEncoderStep(PipelineBlock):

    model_name = "stable-diffusion-xl"