mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
IPAdapterTesterMixin (#6862)
* begin IPAdapterTesterMixin --------- Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
@@ -797,7 +797,11 @@ class AnimateDiffPipeline(
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
||||
added_cond_kwargs = (
|
||||
{"image_embeds": image_embeds}
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
|
||||
for free_init_iter in range(num_free_init_iters):
|
||||
|
||||
@@ -441,6 +441,41 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(
|
||||
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
|
||||
):
|
||||
if ip_adapter_image_embeds is None:
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_negative_image_embeds = torch.stack(
|
||||
[single_negative_image_embeds] * num_images_per_prompt, dim=0
|
||||
)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
else:
|
||||
image_embeds = ip_adapter_image_embeds
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
@@ -735,6 +770,7 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@@ -784,6 +820,9 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
ip_adapter_image: (`PipelineImageInput`, *optional*):
|
||||
Optional image input to work with IP Adapters.
|
||||
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||
Pre-generated image embeddings for IP-Adapter. If not
|
||||
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
|
||||
`np.array`.
|
||||
@@ -870,13 +909,10 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
@@ -902,7 +938,11 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
||||
added_cond_kwargs = (
|
||||
{"image_embeds": image_embeds}
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
|
||||
for free_init_iter in range(num_free_init_iters):
|
||||
|
||||
@@ -1206,7 +1206,11 @@ class StableDiffusionControlNetPipeline(
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7.1 Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
||||
added_cond_kwargs = (
|
||||
{"image_embeds": image_embeds}
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# 7.2 Create tensor stating which controlnets to keep
|
||||
controlnet_keep = []
|
||||
|
||||
@@ -1206,7 +1206,11 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7.1 Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
||||
added_cond_kwargs = (
|
||||
{"image_embeds": image_embeds}
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# 7.2 Create tensor stating which controlnets to keep
|
||||
controlnet_keep = []
|
||||
|
||||
@@ -1495,7 +1495,11 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7.1 Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
||||
added_cond_kwargs = (
|
||||
{"image_embeds": image_embeds}
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# 7.2 Create tensor stating which controlnets to keep
|
||||
controlnet_keep = []
|
||||
|
||||
@@ -477,8 +477,9 @@ class LatentConsistencyModelImg2ImgPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(
|
||||
self, ip_adapter_image, ip_adapter_image_embeds, do_classifier_free_guidance, device, num_images_per_prompt
|
||||
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
|
||||
):
|
||||
if ip_adapter_image_embeds is None:
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
@@ -502,7 +503,7 @@ class LatentConsistencyModelImg2ImgPipeline(
|
||||
[single_negative_image_embeds] * num_images_per_prompt, dim=0
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if self.do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
@@ -699,6 +700,10 @@ class LatentConsistencyModelImg2ImgPipeline(
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
@@ -845,7 +850,7 @@ class LatentConsistencyModelImg2ImgPipeline(
|
||||
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image, ip_adapter_image_embeds, False, device, batch_size * num_images_per_prompt
|
||||
ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
|
||||
)
|
||||
|
||||
# 3. Encode input prompt
|
||||
@@ -860,7 +865,7 @@ class LatentConsistencyModelImg2ImgPipeline(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
False,
|
||||
self.do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=None,
|
||||
@@ -906,7 +911,11 @@ class LatentConsistencyModelImg2ImgPipeline(
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)
|
||||
|
||||
# 7.1 Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
||||
added_cond_kwargs = (
|
||||
{"image_embeds": image_embeds}
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# 8. LCM Multistep Sampling Loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
||||
@@ -461,6 +461,41 @@ class LatentConsistencyModelPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(
|
||||
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
|
||||
):
|
||||
if ip_adapter_image_embeds is None:
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_negative_image_embeds = torch.stack(
|
||||
[single_negative_image_embeds] * num_images_per_prompt, dim=0
|
||||
)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
else:
|
||||
image_embeds = ip_adapter_image_embeds
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
@@ -590,6 +625,10 @@ class LatentConsistencyModelPipeline(
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
@@ -610,6 +649,7 @@ class LatentConsistencyModelPipeline(
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@@ -660,6 +700,9 @@ class LatentConsistencyModelPipeline(
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
ip_adapter_image: (`PipelineImageInput`, *optional*):
|
||||
Optional image input to work with IP Adapters.
|
||||
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||
Pre-generated image embeddings for IP-Adapter. If not
|
||||
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
@@ -726,12 +769,10 @@ class LatentConsistencyModelPipeline(
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
|
||||
)
|
||||
|
||||
# 3. Encode input prompt
|
||||
@@ -746,7 +787,7 @@ class LatentConsistencyModelPipeline(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
False,
|
||||
self.do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=None,
|
||||
@@ -786,7 +827,11 @@ class LatentConsistencyModelPipeline(
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)
|
||||
|
||||
# 7.1 Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
||||
added_cond_kwargs = (
|
||||
{"image_embeds": image_embeds}
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# 8. LCM MultiStep Sampling Loop:
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
||||
@@ -987,7 +987,11 @@ class PIAPipeline(
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
||||
added_cond_kwargs = (
|
||||
{"image_embeds": image_embeds}
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
|
||||
|
||||
@@ -1111,7 +1111,11 @@ class StableDiffusionImg2ImgPipeline(
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7.1 Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
||||
added_cond_kwargs = (
|
||||
{"image_embeds": image_embeds}
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# 7.2 Optionally get Guidance Scale Embedding
|
||||
timestep_cond = None
|
||||
|
||||
@@ -1397,7 +1397,11 @@ class StableDiffusionInpaintPipeline(
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 9.1 Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
||||
added_cond_kwargs = (
|
||||
{"image_embeds": image_embeds}
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# 9.2 Optionally get Guidance Scale Embedding
|
||||
timestep_cond = None
|
||||
|
||||
@@ -777,7 +777,11 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7.1 Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
||||
added_cond_kwargs = (
|
||||
{"image_embeds": image_embeds}
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# 8. Denoising loop
|
||||
# Each denoising step also includes refinement of the latents with respect to the
|
||||
|
||||
@@ -62,7 +62,10 @@ def create_ip_adapter_state_dict(model):
|
||||
key_id = 1
|
||||
|
||||
for name in model.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
|
||||
cross_attention_dim = (
|
||||
None if name.endswith("attn1.processor") or "motion_module" in name else model.config.cross_attention_dim
|
||||
)
|
||||
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = model.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
@@ -71,6 +74,7 @@ def create_ip_adapter_state_dict(model):
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = model.config.block_out_channels[block_id]
|
||||
|
||||
if cross_attention_dim is not None:
|
||||
sd = IPAdapterAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
|
||||
|
||||
@@ -18,7 +18,7 @@ from diffusers.utils import is_xformers_available, logging
|
||||
from diffusers.utils.testing_utils import numpy_cosine_similarity_distance, require_torch_gpu, slow, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin
|
||||
|
||||
|
||||
def to_np(tensor):
|
||||
@@ -28,7 +28,7 @@ def to_np(tensor):
|
||||
return tensor
|
||||
|
||||
|
||||
class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
class AnimateDiffPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = AnimateDiffPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
|
||||
@@ -18,7 +18,7 @@ from diffusers.utils import is_xformers_available, logging
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_PARAMS, VIDEO_TO_VIDEO_BATCH_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin
|
||||
|
||||
|
||||
def to_np(tensor):
|
||||
@@ -28,7 +28,7 @@ def to_np(tensor):
|
||||
return tensor
|
||||
|
||||
|
||||
class AnimateDiffVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
class AnimateDiffVideoToVideoPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = AnimateDiffVideoToVideoPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
batch_params = VIDEO_TO_VIDEO_BATCH_PARAMS
|
||||
|
||||
@@ -54,6 +54,7 @@ from ..pipeline_params import (
|
||||
TEXT_TO_IMAGE_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import (
|
||||
IPAdapterTesterMixin,
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
@@ -110,7 +111,11 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
|
||||
|
||||
|
||||
class ControlNetPipelineFastTests(
|
||||
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
IPAdapterTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = StableDiffusionControlNetPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
@@ -273,7 +278,7 @@ class ControlNetPipelineFastTests(
|
||||
|
||||
|
||||
class StableDiffusionMultiControlNetPipelineFastTests(
|
||||
PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
|
||||
IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = StableDiffusionControlNetPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
@@ -490,7 +495,7 @@ class StableDiffusionMultiControlNetPipelineFastTests(
|
||||
|
||||
|
||||
class StableDiffusionMultiControlNetOneModelPipelineFastTests(
|
||||
PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
|
||||
IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = StableDiffusionControlNetPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
|
||||
@@ -52,6 +52,7 @@ from ..pipeline_params import (
|
||||
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import (
|
||||
IPAdapterTesterMixin,
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
@@ -62,7 +63,11 @@ enable_full_determinism()
|
||||
|
||||
|
||||
class ControlNetImg2ImgPipelineFastTests(
|
||||
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
IPAdapterTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = StableDiffusionControlNetImg2ImgPipeline
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
|
||||
@@ -181,7 +186,7 @@ class ControlNetImg2ImgPipelineFastTests(
|
||||
|
||||
|
||||
class StableDiffusionMultiControlNetPipelineFastTests(
|
||||
PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
|
||||
IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = StableDiffusionControlNetImg2ImgPipeline
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
|
||||
|
||||
@@ -51,11 +51,7 @@ from ..pipeline_params import (
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import (
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
@@ -48,6 +48,7 @@ from ..pipeline_params import (
|
||||
TEXT_TO_IMAGE_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import (
|
||||
IPAdapterTesterMixin,
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
@@ -59,6 +60,7 @@ enable_full_determinism()
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetPipelineFastTests(
|
||||
IPAdapterTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
|
||||
@@ -36,6 +36,7 @@ from ..pipeline_params import (
|
||||
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import (
|
||||
IPAdapterTesterMixin,
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
@@ -46,7 +47,11 @@ enable_full_determinism()
|
||||
|
||||
|
||||
class ControlNetPipelineSDXLImg2ImgFastTests(
|
||||
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
IPAdapterTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = StableDiffusionXLControlNetImg2ImgPipeline
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
|
||||
|
||||
@@ -20,13 +20,15 @@ from diffusers.utils.testing_utils import (
|
||||
)
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
from ..test_pipelines_common import IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LatentConsistencyModelPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
class LatentConsistencyModelPipelineFastTests(
|
||||
IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = LatentConsistencyModelPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"negative_prompt", "negative_prompt_embeds"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {"negative_prompt"}
|
||||
|
||||
@@ -27,14 +27,14 @@ from ..pipeline_params import (
|
||||
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
from ..test_pipelines_common import IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LatentConsistencyModelImg2ImgPipelineFastTests(
|
||||
PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = LatentConsistencyModelImg2ImgPipeline
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "negative_prompt", "negative_prompt_embeds"}
|
||||
|
||||
@@ -17,7 +17,7 @@ from diffusers import (
|
||||
from diffusers.utils import is_xformers_available, logging
|
||||
from diffusers.utils.testing_utils import floats_tensor, torch_device
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin
|
||||
|
||||
|
||||
def to_np(tensor):
|
||||
@@ -27,7 +27,7 @@ def to_np(tensor):
|
||||
return tensor
|
||||
|
||||
|
||||
class PIAPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = PIAPipeline
|
||||
params = frozenset(
|
||||
[
|
||||
|
||||
@@ -23,7 +23,11 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import (
|
||||
CLIPTextConfig,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -60,7 +64,12 @@ from ..pipeline_params import (
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_TO_IMAGE_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
from ..test_pipelines_common import (
|
||||
IPAdapterTesterMixin,
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
@@ -100,7 +109,11 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
|
||||
|
||||
|
||||
class StableDiffusionPipelineFastTests(
|
||||
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
IPAdapterTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = StableDiffusionPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
@@ -177,7 +190,7 @@ class StableDiffusionPipelineFastTests(
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "numpy",
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
|
||||
@@ -55,7 +55,12 @@ from ..pipeline_params import (
|
||||
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
|
||||
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
from ..test_pipelines_common import (
|
||||
IPAdapterTesterMixin,
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
@@ -94,7 +99,11 @@ def _test_img2img_compile(in_queue, out_queue, timeout):
|
||||
|
||||
|
||||
class StableDiffusionImg2ImgPipelineFastTests(
|
||||
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
IPAdapterTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = StableDiffusionImg2ImgPipeline
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
|
||||
|
||||
@@ -57,7 +57,12 @@ from ..pipeline_params import (
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
|
||||
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
from ..test_pipelines_common import (
|
||||
IPAdapterTesterMixin,
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
@@ -98,7 +103,11 @@ def _test_inpaint_compile(in_queue, out_queue, timeout):
|
||||
|
||||
|
||||
class StableDiffusionInpaintPipelineFastTests(
|
||||
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
IPAdapterTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = StableDiffusionInpaintPipeline
|
||||
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
|
||||
|
||||
@@ -47,7 +47,11 @@ from ..pipeline_params import (
|
||||
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
|
||||
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
from ..test_pipelines_common import (
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
@@ -49,14 +49,23 @@ from ..pipeline_params import (
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_TO_IMAGE_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin
|
||||
from ..test_pipelines_common import (
|
||||
IPAdapterTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
SDXLOptionalComponentsTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class StableDiffusionXLPipelineFastTests(
|
||||
PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
|
||||
IPAdapterTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
SDXLOptionalComponentsTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = StableDiffusionXLPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
|
||||
@@ -44,6 +44,7 @@ from diffusers.utils.testing_utils import (
|
||||
|
||||
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
|
||||
from ..test_pipelines_common import (
|
||||
IPAdapterTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
SDXLOptionalComponentsTesterMixin,
|
||||
assert_mean_pixel_difference,
|
||||
@@ -54,7 +55,7 @@ enable_full_determinism()
|
||||
|
||||
|
||||
class StableDiffusionXLAdapterPipelineFastTests(
|
||||
PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
|
||||
IPAdapterTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = StableDiffusionXLAdapterPipeline
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
|
||||
|
||||
@@ -54,13 +54,20 @@ from ..pipeline_params import (
|
||||
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
|
||||
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin
|
||||
from ..test_pipelines_common import (
|
||||
IPAdapterTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
SDXLOptionalComponentsTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
class StableDiffusionXLImg2ImgPipelineFastTests(
|
||||
IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = StableDiffusionXLImg2ImgPipeline
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
|
||||
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
|
||||
|
||||
@@ -48,13 +48,15 @@ from ..pipeline_params import (
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
|
||||
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
from ..test_pipelines_common import IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
class StableDiffusionXLInpaintPipelineFastTests(
|
||||
IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = StableDiffusionXLInpaintPipeline
|
||||
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
|
||||
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
|
||||
|
||||
@@ -8,7 +8,7 @@ import re
|
||||
import tempfile
|
||||
import unittest
|
||||
import uuid
|
||||
from typing import Callable, Union
|
||||
from typing import Any, Callable, Dict, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
@@ -29,6 +29,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import IPAdapterMixin
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
|
||||
@@ -44,6 +45,7 @@ from ..models.autoencoders.test_models_vae import (
|
||||
get_autoencoder_tiny_config,
|
||||
get_consistency_vae_config,
|
||||
)
|
||||
from ..models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
|
||||
from ..others.test_utils import TOKEN, USER, is_staging_test
|
||||
|
||||
|
||||
@@ -59,6 +61,118 @@ def check_same_shape(tensor_list):
|
||||
return all(shape == shapes[0] for shape in shapes[1:])
|
||||
|
||||
|
||||
class IPAdapterTesterMixin:
|
||||
"""
|
||||
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
|
||||
It provides a set of common tests for pipelines that support IP Adapters.
|
||||
"""
|
||||
|
||||
def test_pipeline_signature(self):
|
||||
parameters = inspect.signature(self.pipeline_class.__call__).parameters
|
||||
|
||||
assert issubclass(self.pipeline_class, IPAdapterMixin)
|
||||
self.assertIn(
|
||||
"ip_adapter_image",
|
||||
parameters,
|
||||
"`ip_adapter_image` argument must be supported by the `__call__` method",
|
||||
)
|
||||
self.assertIn(
|
||||
"ip_adapter_image_embeds",
|
||||
parameters,
|
||||
"`ip_adapter_image_embeds` argument must be supported by the `__call__` method",
|
||||
)
|
||||
|
||||
def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
|
||||
return torch.randn((2, 1, cross_attention_dim), device=torch_device)
|
||||
|
||||
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
|
||||
parameters = inspect.signature(self.pipeline_class.__call__).parameters
|
||||
if "image" in parameters.keys() and "strength" in parameters.keys():
|
||||
inputs["num_inference_steps"] = 4
|
||||
|
||||
inputs["output_type"] = "np"
|
||||
inputs["return_dict"] = False
|
||||
return inputs
|
||||
|
||||
def test_ip_adapter_single(self, expected_max_diff: float = 1e-4):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components).to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
|
||||
|
||||
# forward pass without ip adapter
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
output_without_adapter = pipe(**inputs)[0]
|
||||
|
||||
adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
|
||||
pipe.unet._load_ip_adapter_weights(adapter_state_dict)
|
||||
|
||||
# forward pass with single ip adapter, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
pipe.set_ip_adapter_scale(0.0)
|
||||
output_without_adapter_scale = pipe(**inputs)[0]
|
||||
|
||||
# forward pass with single ip adapter, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
pipe.set_ip_adapter_scale(42.0)
|
||||
output_with_adapter_scale = pipe(**inputs)[0]
|
||||
|
||||
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
|
||||
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
|
||||
|
||||
self.assertLess(
|
||||
max_diff_without_adapter_scale,
|
||||
expected_max_diff,
|
||||
"Output without ip-adapter must be same as normal inference",
|
||||
)
|
||||
self.assertGreater(
|
||||
max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference"
|
||||
)
|
||||
|
||||
def test_ip_adapter_multi(self, expected_max_diff: float = 1e-4):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components).to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
|
||||
|
||||
# forward pass without ip adapter
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
output_without_adapter = pipe(**inputs)[0]
|
||||
|
||||
adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet)
|
||||
adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet)
|
||||
pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
|
||||
|
||||
# forward pass with multi ip adapter, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
pipe.set_ip_adapter_scale([0.0, 0.0])
|
||||
output_without_multi_adapter_scale = pipe(**inputs)[0]
|
||||
|
||||
# forward pass with multi ip adapter, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
pipe.set_ip_adapter_scale([42.0, 42.0])
|
||||
output_with_multi_adapter_scale = pipe(**inputs)[0]
|
||||
|
||||
max_diff_without_multi_adapter_scale = np.abs(
|
||||
output_without_multi_adapter_scale - output_without_adapter
|
||||
).max()
|
||||
max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max()
|
||||
self.assertLess(
|
||||
max_diff_without_multi_adapter_scale,
|
||||
expected_max_diff,
|
||||
"Output without multi-ip-adapter must be same as normal inference",
|
||||
)
|
||||
self.assertGreater(
|
||||
max_diff_with_multi_adapter_scale,
|
||||
1e-2,
|
||||
"Output with multi-ip-adapter scale must be different from normal inference",
|
||||
)
|
||||
|
||||
|
||||
class PipelineLatentTesterMixin:
|
||||
"""
|
||||
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
|
||||
|
||||
Reference in New Issue
Block a user