mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Support InstantStyle (#7668)
* enable control ip-adapter per-transformer block on-the-fly

Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
Co-authored-by: ResearcherXman <xhs.research@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
@@ -640,3 +640,87 @@ image
<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ipa-controlnet-out.png" />
</div>

### Style & layout control

[InstantStyle](https://arxiv.org/abs/2404.02733) is a plug-and-play method built on top of IP-Adapter that disentangles style and layout from the image prompt to control image generation. It works by inserting IP-Adapters only into specific parts of the model.

By default, IP-Adapters are inserted into all layers of the model. Use the [`~loaders.IPAdapterMixin.set_ip_adapter_scale`] method with a dictionary to assign scales to the IP-Adapter at different layers.

```py
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image
import torch

pipeline = AutoPipelineForImage2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

scale = {
    "down": {"block_2": [0.0, 1.0]},
    "up": {"block_0": [0.0, 1.0, 0.0]},
}
pipeline.set_ip_adapter_scale(scale)
```

This activates the IP-Adapter at the second layer of the model's down-part block 2 and up-part block 0. The former is the layer where the IP-Adapter injects layout information and the latter injects style. Inserting the IP-Adapter into these two layers generates images that follow the style and layout of the image prompt, with contents more aligned to the text prompt.

```py
style_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg")

generator = torch.Generator(device="cpu").manual_seed(42)
image = pipeline(
    prompt="a cat, masterpiece, best quality, high quality",
    image=style_image,
    negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
    guidance_scale=5,
    num_inference_steps=30,
    generator=generator,
).images[0]
image
```

<div class="flex flex-row gap-4">
  <div class="flex-1">
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"/>
    <figcaption class="mt-2 text-center text-sm text-gray-500">IP-Adapter image</figcaption>
  </div>
  <div class="flex-1">
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit_style_layout_cat.png"/>
    <figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
  </div>
</div>

In contrast, inserting the IP-Adapter into all layers often produces images that focus too heavily on the image prompt and show less diversity.

Activate the IP-Adapter only in the style layer and then call the pipeline again.

```py
scale = {
    "up": {"block_0": [0.0, 1.0, 0.0]},
}
pipeline.set_ip_adapter_scale(scale)

generator = torch.Generator(device="cpu").manual_seed(42)
image = pipeline(
    prompt="a cat, masterpiece, best quality, high quality",
    image=style_image,
    negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
    guidance_scale=5,
    num_inference_steps=30,
    generator=generator,
).images[0]
image
```

<div class="flex flex-row gap-4">
  <div class="flex-1">
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit_style_cat.png"/>
    <figcaption class="mt-2 text-center text-sm text-gray-500">IP-Adapter only in style layer</figcaption>
  </div>
  <div class="flex-1">
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/30518dfe089e6bf50008875077b44cb98fb2065c/diffusers/default_out.png"/>
    <figcaption class="mt-2 text-center text-sm text-gray-500">IP-Adapter in all layers</figcaption>
  </div>
</div>

Note that you don't have to specify all layers in the dictionary. Layers not included in the dictionary are set to a scale of 0, which disables the IP-Adapter there by default.
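
As a rough illustration (this snippet is not part of the original guide), a config can also mix floats and per-block lists; a float fans out to every transformer layer in that part of the UNet, while anything you leave out keeps the default scale of 0:

```py
# Hypothetical mixed config: 0.5 for every layer in the down-part blocks,
# per-layer scales for up-part block 0, and everything else (e.g. the mid block)
# left at 0.0, i.e. the IP-Adapter is disabled there.
scale = {
    "down": 0.5,
    "up": {"block_0": [0.0, 1.0, 0.0]},
}
pipeline.set_ip_adapter_scale(scale)
```
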
@@ -28,6 +28,7 @@ from ..utils import (
    is_transformers_available,
    logging,
)
from .unet_loader_utils import _maybe_expand_lora_scales


if is_transformers_available():
@@ -243,25 +244,55 @@ class IPAdapterMixin:
    def set_ip_adapter_scale(self, scale):
        """
        Sets the conditioning scale between text and image.
        Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
        granular control over each IP-Adapter behavior. A config can be a float or a dictionary.

        Example:

        ```py
        pipeline.set_ip_adapter_scale(0.5)
        # To use original IP-Adapter
        scale = 1.0
        pipeline.set_ip_adapter_scale(scale)

        # To use style block only
        scale = {
            "up": {"block_0": [0.0, 1.0, 0.0]},
        }
        pipeline.set_ip_adapter_scale(scale)

        # To use style+layout blocks
        scale = {
            "down": {"block_2": [0.0, 1.0]},
            "up": {"block_0": [0.0, 1.0, 0.0]},
        }
        pipeline.set_ip_adapter_scale(scale)

        # To use style and layout from 2 reference images
        scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
        pipeline.set_ip_adapter_scale(scales)
        ```
        """
        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
        for attn_processor in unet.attn_processors.values():
        if not isinstance(scale, list):
            scale = [scale]
        scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)

        for attn_name, attn_processor in unet.attn_processors.items():
            if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
                if not isinstance(scale, list):
                    scale = [scale] * len(attn_processor.scale)
                if len(attn_processor.scale) != len(scale):
                if len(scale_configs) != len(attn_processor.scale):
                    raise ValueError(
                        f"`scale` should be a list of same length as the number if ip-adapters "
                        f"Expected {len(attn_processor.scale)} but got {len(scale)}."
                        f"Cannot assign {len(scale_configs)} scale_configs to "
                        f"{len(attn_processor.scale)} IP-Adapter."
                    )
                attn_processor.scale = scale
                elif len(scale_configs) == 1:
                    scale_configs = scale_configs * len(attn_processor.scale)
                for i, scale_config in enumerate(scale_configs):
                    if isinstance(scale_config, dict):
                        for k, s in scale_config.items():
                            if attn_name.startswith(k):
                                attn_processor.scale[i] = s
                    else:
                        attn_processor.scale[i] = scale_config

    def unload_ip_adapter(self):
        """
@@ -38,7 +38,9 @@ def _translate_into_actual_layer_name(name):
    return ".".join((updown, block, attn))


def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]]):
def _maybe_expand_lora_scales(
    unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0
):
    blocks_with_transformer = {
        "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
        "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
@@ -47,7 +49,11 @@ def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: List[

    expanded_weight_scales = [
        _maybe_expand_lora_scales_for_one_adapter(
            weight_for_adapter, blocks_with_transformer, transformer_per_block, unet.state_dict()
            weight_for_adapter,
            blocks_with_transformer,
            transformer_per_block,
            unet.state_dict(),
            default_scale=default_scale,
        )
        for weight_for_adapter in weight_scales
    ]
@@ -60,6 +66,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
    blocks_with_transformer: Dict[str, int],
    transformer_per_block: Dict[str, int],
    state_dict: None,
    default_scale: float = 1.0,
):
    """
    Expands the inputs into a more granular dictionary. See the example below for more details.
@@ -108,21 +115,36 @@ def _maybe_expand_lora_scales_for_one_adapter(
    scales = copy.deepcopy(scales)

    if "mid" not in scales:
        scales["mid"] = 1
        scales["mid"] = default_scale
    elif isinstance(scales["mid"], list):
        if len(scales["mid"]) == 1:
            scales["mid"] = scales["mid"][0]
        else:
            raise ValueError(f"Expected 1 scales for mid, got {len(scales['mid'])}.")

    for updown in ["up", "down"]:
        if updown not in scales:
            scales[updown] = 1
            scales[updown] = default_scale

        # eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}}
        if not isinstance(scales[updown], dict):
            scales[updown] = {f"block_{i}": scales[updown] for i in blocks_with_transformer[updown]}
            scales[updown] = {f"block_{i}": copy.deepcopy(scales[updown]) for i in blocks_with_transformer[updown]}

        # eg {"down": "block_1": 1}} to {"down": "block_1": [1, 1]}}
        # eg {"down": {"block_1": 1}} to {"down": {"block_1": [1, 1]}}
        for i in blocks_with_transformer[updown]:
            block = f"block_{i}"
            # set not assigned blocks to default scale
            if block not in scales[updown]:
                scales[updown][block] = default_scale
            if not isinstance(scales[updown][block], list):
                scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]
            elif len(scales[updown][block]) == 1:
                # a list specifying scale to each masked IP input
                scales[updown][block] = scales[updown][block] * transformer_per_block[updown]
            elif len(scales[updown][block]) != transformer_per_block[updown]:
                raise ValueError(
                    f"Expected {transformer_per_block[updown]} scales for {updown}.{block}, got {len(scales[updown][block])}."
                )

    # eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1}
    for i in blocks_with_transformer[updown]:
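
To make the expansion above concrete, here is a small self-contained sketch (not the library code; the block indices and per-block layer counts below are made-up toy values) of how a user-facing config fans out into per-layer scales, following the inline comments in `_maybe_expand_lora_scales_for_one_adapter`:

```py
# Toy re-implementation of the expansion steps described in the comments above.
def expand_scales(scales, blocks_with_transformer, transformer_per_block, default_scale=0.0):
    expanded = {}
    for updown in ("down", "up"):
        per_block = scales.get(updown, default_scale)
        if not isinstance(per_block, dict):
            # eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}
            per_block = {f"block_{i}": per_block for i in blocks_with_transformer[updown]}
        for i in blocks_with_transformer[updown]:
            value = per_block.get(f"block_{i}", default_scale)
            if not isinstance(value, list):
                # eg {"block_1": 1} to {"block_1": [1, 1]}
                value = [value] * transformer_per_block[updown]
            # eg {"down": {"block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1}
            for j, v in enumerate(value):
                expanded[f"{updown}.block_{i}.{j}"] = v
    expanded["mid"] = scales.get("mid", default_scale)
    return expanded


# Hypothetical toy layout: down blocks 1-2 and up blocks 0-1 carry transformers.
print(expand_scales(
    {"up": {"block_0": [0.0, 1.0, 0.0]}},
    blocks_with_transformer={"down": [1, 2], "up": [0, 1]},
    transformer_per_block={"down": 2, "up": 3},
))
# every unspecified layer (down blocks, up block 1, mid) comes out as 0.0
```
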
@@ -2229,44 +2229,51 @@ class IPAdapterAttnProcessor(nn.Module):
        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
        ):
            if mask is not None:
                if not isinstance(scale, list):
                    scale = [scale]
            skip = False
            if isinstance(scale, list):
                if all(s == 0 for s in scale):
                    skip = True
            elif scale == 0:
                skip = True
            if not skip:
                if mask is not None:
                    if not isinstance(scale, list):
                        scale = [scale] * mask.shape[1]

                current_num_images = mask.shape[1]
                for i in range(current_num_images):
                    ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
                    ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
                    current_num_images = mask.shape[1]
                    for i in range(current_num_images):
                        ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
                        ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])

                        ip_key = attn.head_to_batch_dim(ip_key)
                        ip_value = attn.head_to_batch_dim(ip_value)

                        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
                        _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
                        _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)

                        mask_downsample = IPAdapterMaskProcessor.downsample(
                            mask[:, i, :, :],
                            batch_size,
                            _current_ip_hidden_states.shape[1],
                            _current_ip_hidden_states.shape[2],
                        )

                        mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)

                        hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
                else:
                    ip_key = to_k_ip(current_ip_hidden_states)
                    ip_value = to_v_ip(current_ip_hidden_states)

                    ip_key = attn.head_to_batch_dim(ip_key)
                    ip_value = attn.head_to_batch_dim(ip_value)

                    ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
                    _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
                    _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
                    current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
                    current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)

                    mask_downsample = IPAdapterMaskProcessor.downsample(
                        mask[:, i, :, :],
                        batch_size,
                        _current_ip_hidden_states.shape[1],
                        _current_ip_hidden_states.shape[2],
                    )

                    mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)

                    hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
            else:
                ip_key = to_k_ip(current_ip_hidden_states)
                ip_value = to_v_ip(current_ip_hidden_states)

                ip_key = attn.head_to_batch_dim(ip_key)
                ip_value = attn.head_to_batch_dim(ip_value)

                ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
                current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
                current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)

                hidden_states = hidden_states + scale * current_ip_hidden_states
                    hidden_states = hidden_states + scale * current_ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
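
The control flow added above boils down to a simple guard. As a standalone sketch (not library code), an IP-Adapter whose scales are all zero for a given attention layer is skipped outright instead of being computed and then multiplied by zero:

```py
# Mirrors the skip check introduced in the processors above.
def should_skip(scale):
    if isinstance(scale, list):
        return all(s == 0 for s in scale)
    return scale == 0


assert should_skip(0.0)
assert should_skip([0.0, 0.0])
assert not should_skip([0.0, 1.0, 0.0])
```
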
@@ -2439,57 +2446,64 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
        ):
            if mask is not None:
                if not isinstance(scale, list):
                    scale = [scale]
            skip = False
            if isinstance(scale, list):
                if all(s == 0 for s in scale):
                    skip = True
            elif scale == 0:
                skip = True
            if not skip:
                if mask is not None:
                    if not isinstance(scale, list):
                        scale = [scale] * mask.shape[1]

                current_num_images = mask.shape[1]
                for i in range(current_num_images):
                    ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
                    ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
                    current_num_images = mask.shape[1]
                    for i in range(current_num_images):
                        ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
                        ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])

                        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

                        # the output of sdp = (batch, num_heads, seq_len, head_dim)
                        # TODO: add support for attn.scale when we move to Torch 2.1
                        _current_ip_hidden_states = F.scaled_dot_product_attention(
                            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
                        )

                        _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
                            batch_size, -1, attn.heads * head_dim
                        )
                        _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)

                        mask_downsample = IPAdapterMaskProcessor.downsample(
                            mask[:, i, :, :],
                            batch_size,
                            _current_ip_hidden_states.shape[1],
                            _current_ip_hidden_states.shape[2],
                        )

                        mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
                        hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
                else:
                    ip_key = to_k_ip(current_ip_hidden_states)
                    ip_value = to_v_ip(current_ip_hidden_states)

                    ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                    ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
                    # TODO: add support for attn.scale when we move to Torch 2.1
                    _current_ip_hidden_states = F.scaled_dot_product_attention(
                    current_ip_hidden_states = F.scaled_dot_product_attention(
                        query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
                    )

                    _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
                    current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
                        batch_size, -1, attn.heads * head_dim
                    )
                    _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
                    current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)

                    mask_downsample = IPAdapterMaskProcessor.downsample(
                        mask[:, i, :, :],
                        batch_size,
                        _current_ip_hidden_states.shape[1],
                        _current_ip_hidden_states.shape[2],
                    )

                    mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
                    hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
            else:
                ip_key = to_k_ip(current_ip_hidden_states)
                ip_value = to_v_ip(current_ip_hidden_states)

                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

                # the output of sdp = (batch, num_heads, seq_len, head_dim)
                # TODO: add support for attn.scale when we move to Torch 2.1
                current_ip_hidden_states = F.scaled_dot_product_attention(
                    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
                )

                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
                    batch_size, -1, attn.heads * head_dim
                )
                current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)

                hidden_states = hidden_states + scale * current_ip_hidden_states
                    hidden_states = hidden_states + scale * current_ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
@@ -73,7 +73,9 @@ class IPAdapterNightlyTestsMixin(unittest.TestCase):
        image_processor = CLIPImageProcessor.from_pretrained(repo_id)
        return image_processor

    def get_dummy_inputs(self, for_image_to_image=False, for_inpainting=False, for_sdxl=False, for_masks=False):
    def get_dummy_inputs(
        self, for_image_to_image=False, for_inpainting=False, for_sdxl=False, for_masks=False, for_instant_style=False
    ):
        image = load_image(
            "https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png"
        )
@@ -126,6 +128,40 @@ class IPAdapterNightlyTestsMixin(unittest.TestCase):
                }
            )

        elif for_instant_style:
            composition_mask = load_image(
                "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/1024_whole_mask.png"
            )
            female_mask = load_image(
                "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter_None_20240321125641_mask.png"
            )
            male_mask = load_image(
                "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter_None_20240321125344_mask.png"
            )
            background_mask = load_image(
                "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter_6_20240321130722_mask.png"
            )
            ip_composition_image = load_image(
                "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321125152.png"
            )
            ip_female_style = load_image(
                "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321125625.png"
            )
            ip_male_style = load_image(
                "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321125329.png"
            )
            ip_background = load_image(
                "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321130643.png"
            )
            input_kwargs.update(
                {
                    "ip_adapter_image": [ip_composition_image, [ip_female_style, ip_male_style, ip_background]],
                    "cross_attention_kwargs": {
                        "ip_adapter_masks": [[composition_mask], [female_mask, male_mask, background_mask]]
                    },
                }
            )

        return input_kwargs
@@ -575,6 +611,48 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
        assert max_diff < 5e-4

    def test_instant_style_multiple_masks(self):
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
        ).to("cuda")
        pipeline = StableDiffusionXLPipeline.from_pretrained(
            "RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.float16, image_encoder=image_encoder, variant="fp16"
        ).to("cuda")
        pipeline.enable_model_cpu_offload()

        pipeline.load_ip_adapter(
            ["ostris/ip-composition-adapter", "h94/IP-Adapter"],
            subfolder=["", "sdxl_models"],
            weight_name=[
                "ip_plus_composition_sdxl.safetensors",
                "ip-adapter_sdxl_vit-h.safetensors",
            ],
            image_encoder_folder=None,
        )
        scale_1 = {
            "down": [[0.0, 0.0, 1.0]],
            "mid": [[0.0, 0.0, 1.0]],
            "up": {"block_0": [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]], "block_1": [[0.0, 0.0, 1.0]]},
        }
        pipeline.set_ip_adapter_scale([1.0, scale_1])

        inputs = self.get_dummy_inputs(for_instant_style=True)
        processor = IPAdapterMaskProcessor()
        masks1 = inputs["cross_attention_kwargs"]["ip_adapter_masks"][0]
        masks2 = inputs["cross_attention_kwargs"]["ip_adapter_masks"][1]
        masks1 = processor.preprocess(masks1, height=1024, width=1024)
        masks2 = processor.preprocess(masks2, height=1024, width=1024)
        masks2 = masks2.reshape(1, masks2.shape[0], masks2.shape[2], masks2.shape[3])
        inputs["cross_attention_kwargs"]["ip_adapter_masks"] = [masks1, masks2]
        images = pipeline(**inputs).images
        image_slice = images[0, :3, :3, -1].flatten()
        expected_slice = np.array(
            [0.23551631, 0.20476806, 0.14099443, 0.0, 0.07675594, 0.05672678, 0.0, 0.0, 0.02099729]
        )

        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
        assert max_diff < 5e-4

    def test_ip_adapter_multiple_masks_one_adapter(self):
        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
        pipeline = StableDiffusionXLPipeline.from_pretrained(