mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[fix] multi t2i adapter set total_downscale_factor (#4621)
* [fix] multi t2i adapter set total_downscale_factor * move image checks into check inputs * remove copied from
This commit is contained in:
@@ -41,6 +41,31 @@ class MultiAdapter(ModelMixin):
|
||||
self.num_adapter = len(adapters)
|
||||
self.adapters = nn.ModuleList(adapters)
|
||||
|
||||
if len(adapters) == 0:
|
||||
raise ValueError("Expecting at least one adapter")
|
||||
|
||||
if len(adapters) == 1:
|
||||
raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`")
|
||||
|
||||
# The outputs from each adapter are added together with a weight
|
||||
# This means that the change in dimenstions from downsampling must
|
||||
# be the same for all adapters. Inductively, it also means the total
|
||||
# downscale factor must also be the same for all adapters.
|
||||
|
||||
first_adapter_total_downscale_factor = adapters[0].total_downscale_factor
|
||||
|
||||
for idx in range(1, len(adapters)):
|
||||
adapter_idx_total_downscale_factor = adapters[idx].total_downscale_factor
|
||||
|
||||
if adapter_idx_total_downscale_factor != first_adapter_total_downscale_factor:
|
||||
raise ValueError(
|
||||
f"Expecting all adapters to have the same total_downscale_factor, "
|
||||
f"but got adapters[0].total_downscale_factor={first_adapter_total_downscale_factor} and "
|
||||
f"adapter[`{idx}`]={adapter_idx_total_downscale_factor}"
|
||||
)
|
||||
|
||||
self.total_downscale_factor = adapters[0].total_downscale_factor
|
||||
|
||||
def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
@@ -56,14 +81,8 @@ class MultiAdapter(ModelMixin):
|
||||
else:
|
||||
adapter_weights = torch.tensor(adapter_weights)
|
||||
|
||||
if xs.shape[1] % self.num_adapter != 0:
|
||||
raise ValueError(
|
||||
f"Expecting multi-adapter's input have number of channel that cab be evenly divisible "
|
||||
f"by num_adapter: {xs.shape[1]} % {self.num_adapter} != 0"
|
||||
)
|
||||
x_list = torch.chunk(xs, self.num_adapter, dim=1)
|
||||
accume_state = None
|
||||
for x, w, adapter in zip(x_list, adapter_weights, self.adapters):
|
||||
for x, w, adapter in zip(xs, adapter_weights, self.adapters):
|
||||
features = adapter(x)
|
||||
if accume_state is None:
|
||||
accume_state = features
|
||||
|
||||
@@ -453,13 +453,13 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
image,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
@@ -501,6 +501,17 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if isinstance(self.adapter, MultiAdapter):
|
||||
if not isinstance(image, list):
|
||||
raise ValueError(
|
||||
"MultiAdapter is enabled, but `image` is not a list. Please pass a list of images to `image`."
|
||||
)
|
||||
|
||||
if len(image) != len(self.adapter.adapters):
|
||||
raise ValueError(
|
||||
f"MultiAdapter requires passing the same number of images as adapters. Given {len(image)} images and {len(self.adapter.adapters)} adapters."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
@@ -653,17 +664,19 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
|
||||
prompt, height, width, callback_steps, image, negative_prompt, prompt_embeds, negative_prompt_embeds
|
||||
)
|
||||
|
||||
is_multi_adapter = isinstance(self.adapter, MultiAdapter)
|
||||
if is_multi_adapter:
|
||||
adapter_input = [_preprocess_adapter_image(img, height, width).to(device) for img in image]
|
||||
n, c, h, w = adapter_input[0].shape
|
||||
adapter_input = torch.stack([x.reshape([n * c, h, w]) for x in adapter_input])
|
||||
if isinstance(self.adapter, MultiAdapter):
|
||||
adapter_input = []
|
||||
|
||||
for one_image in image:
|
||||
one_image = _preprocess_adapter_image(one_image, height, width)
|
||||
one_image = one_image.to(device=device, dtype=self.adapter.dtype)
|
||||
adapter_input.append(one_image)
|
||||
else:
|
||||
adapter_input = _preprocess_adapter_image(image, height, width).to(device)
|
||||
adapter_input = adapter_input.to(self.adapter.dtype)
|
||||
adapter_input = _preprocess_adapter_image(image, height, width)
|
||||
adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
|
||||
@@ -21,19 +21,21 @@ import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
import diffusers
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
MultiAdapter,
|
||||
PNDMScheduler,
|
||||
StableDiffusionAdapterPipeline,
|
||||
T2IAdapter,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
|
||||
from diffusers.utils import floats_tensor, load_image, load_numpy, logging, slow, torch_device
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
|
||||
|
||||
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
@@ -82,13 +84,38 @@ class AdapterTests:
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
torch.manual_seed(0)
|
||||
adapter = T2IAdapter(
|
||||
in_channels=3,
|
||||
channels=[32, 64],
|
||||
num_res_blocks=2,
|
||||
downscale_factor=2,
|
||||
adapter_type=adapter_type,
|
||||
)
|
||||
|
||||
if adapter_type == "full_adapter" or adapter_type == "light_adapter":
|
||||
adapter = T2IAdapter(
|
||||
in_channels=3,
|
||||
channels=[32, 64],
|
||||
num_res_blocks=2,
|
||||
downscale_factor=2,
|
||||
adapter_type=adapter_type,
|
||||
)
|
||||
elif adapter_type == "multi_adapter":
|
||||
adapter = MultiAdapter(
|
||||
[
|
||||
T2IAdapter(
|
||||
in_channels=3,
|
||||
channels=[32, 64],
|
||||
num_res_blocks=2,
|
||||
downscale_factor=2,
|
||||
adapter_type="full_adapter",
|
||||
),
|
||||
T2IAdapter(
|
||||
in_channels=3,
|
||||
channels=[32, 64],
|
||||
num_res_blocks=2,
|
||||
downscale_factor=2,
|
||||
adapter_type="full_adapter",
|
||||
),
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown adapter type: {adapter_type}, must be one of 'full_adapter', 'light_adapter', or 'multi_adapter''"
|
||||
)
|
||||
|
||||
components = {
|
||||
"adapter": adapter,
|
||||
@@ -102,8 +129,12 @@ class AdapterTests:
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
|
||||
def get_dummy_inputs(self, device, seed=0, num_images=1):
|
||||
if num_images == 1:
|
||||
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
|
||||
else:
|
||||
image = [floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) for _ in range(num_images)]
|
||||
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
@@ -172,6 +203,217 @@ class StableDiffusionLightAdapterPipelineFastTests(AdapterTests, PipelineTesterM
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
|
||||
|
||||
|
||||
class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
|
||||
def get_dummy_components(self):
|
||||
return super().get_dummy_components("multi_adapter")
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
return super().get_dummy_inputs(device, seed, num_images=2)
|
||||
|
||||
def test_stable_diffusion_adapter_default_case(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionAdapterPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.4902, 0.5539, 0.4317, 0.4682, 0.6190, 0.4351, 0.5018, 0.5046, 0.4772])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
|
||||
|
||||
def test_inference_batch_consistent(
|
||||
self, batch_sizes=[2, 4, 13], additional_params_copy_to_batched_inputs=["num_inference_steps"]
|
||||
):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
logger = logging.get_logger(pipe.__module__)
|
||||
logger.setLevel(level=diffusers.logging.FATAL)
|
||||
|
||||
# batchify inputs
|
||||
for batch_size in batch_sizes:
|
||||
batched_inputs = {}
|
||||
for name, value in inputs.items():
|
||||
if name in self.batch_params:
|
||||
# prompt is string
|
||||
if name == "prompt":
|
||||
len_prompt = len(value)
|
||||
# make unequal batch sizes
|
||||
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
|
||||
|
||||
# make last batch super long
|
||||
batched_inputs[name][-1] = 100 * "very long"
|
||||
elif name == "image":
|
||||
batched_images = []
|
||||
|
||||
for image in value:
|
||||
batched_images.append(batch_size * [image])
|
||||
|
||||
batched_inputs[name] = batched_images
|
||||
else:
|
||||
batched_inputs[name] = batch_size * [value]
|
||||
|
||||
elif name == "batch_size":
|
||||
batched_inputs[name] = batch_size
|
||||
else:
|
||||
batched_inputs[name] = value
|
||||
|
||||
for arg in additional_params_copy_to_batched_inputs:
|
||||
batched_inputs[arg] = inputs[arg]
|
||||
|
||||
batched_inputs["output_type"] = "np"
|
||||
|
||||
if self.pipeline_class.__name__ == "DanceDiffusionPipeline":
|
||||
batched_inputs.pop("output_type")
|
||||
|
||||
output = pipe(**batched_inputs)
|
||||
|
||||
assert len(output[0]) == batch_size
|
||||
|
||||
batched_inputs["output_type"] = "np"
|
||||
|
||||
if self.pipeline_class.__name__ == "DanceDiffusionPipeline":
|
||||
batched_inputs.pop("output_type")
|
||||
|
||||
output = pipe(**batched_inputs)[0]
|
||||
|
||||
assert output.shape[0] == batch_size
|
||||
|
||||
logger.setLevel(level=diffusers.logging.WARNING)
|
||||
|
||||
def test_num_images_per_prompt(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
batch_sizes = [1, 2]
|
||||
num_images_per_prompts = [1, 2]
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
for num_images_per_prompt in num_images_per_prompts:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
for key in inputs.keys():
|
||||
if key in self.batch_params:
|
||||
if key == "image":
|
||||
batched_images = []
|
||||
|
||||
for image in inputs[key]:
|
||||
batched_images.append(batch_size * [image])
|
||||
|
||||
inputs[key] = batched_images
|
||||
else:
|
||||
inputs[key] = batch_size * [inputs[key]]
|
||||
|
||||
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0]
|
||||
|
||||
assert images.shape[0] == batch_size * num_images_per_prompt
|
||||
|
||||
def test_inference_batch_single_identical(
|
||||
self,
|
||||
batch_size=3,
|
||||
test_max_difference=None,
|
||||
test_mean_pixel_difference=None,
|
||||
relax_max_difference=False,
|
||||
expected_max_diff=2e-3,
|
||||
additional_params_copy_to_batched_inputs=["num_inference_steps"],
|
||||
):
|
||||
if test_max_difference is None:
|
||||
# TODO(Pedro) - not sure why, but not at all reproducible at the moment it seems
|
||||
# make sure that batched and non-batched is identical
|
||||
test_max_difference = torch_device != "mps"
|
||||
|
||||
if test_mean_pixel_difference is None:
|
||||
# TODO same as above
|
||||
test_mean_pixel_difference = torch_device != "mps"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
logger = logging.get_logger(pipe.__module__)
|
||||
logger.setLevel(level=diffusers.logging.FATAL)
|
||||
|
||||
# batchify inputs
|
||||
batched_inputs = {}
|
||||
batch_size = batch_size
|
||||
for name, value in inputs.items():
|
||||
if name in self.batch_params:
|
||||
# prompt is string
|
||||
if name == "prompt":
|
||||
len_prompt = len(value)
|
||||
# make unequal batch sizes
|
||||
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
|
||||
|
||||
# make last batch super long
|
||||
batched_inputs[name][-1] = 100 * "very long"
|
||||
elif name == "image":
|
||||
batched_images = []
|
||||
|
||||
for image in value:
|
||||
batched_images.append(batch_size * [image])
|
||||
|
||||
batched_inputs[name] = batched_images
|
||||
else:
|
||||
batched_inputs[name] = batch_size * [value]
|
||||
elif name == "batch_size":
|
||||
batched_inputs[name] = batch_size
|
||||
elif name == "generator":
|
||||
batched_inputs[name] = [self.get_generator(i) for i in range(batch_size)]
|
||||
else:
|
||||
batched_inputs[name] = value
|
||||
|
||||
for arg in additional_params_copy_to_batched_inputs:
|
||||
batched_inputs[arg] = inputs[arg]
|
||||
|
||||
if self.pipeline_class.__name__ != "DanceDiffusionPipeline":
|
||||
batched_inputs["output_type"] = "np"
|
||||
|
||||
output_batch = pipe(**batched_inputs)
|
||||
assert output_batch[0].shape[0] == batch_size
|
||||
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
|
||||
output = pipe(**inputs)
|
||||
|
||||
logger.setLevel(level=diffusers.logging.WARNING)
|
||||
if test_max_difference:
|
||||
if relax_max_difference:
|
||||
# Taking the median of the largest <n> differences
|
||||
# is resilient to outliers
|
||||
diff = np.abs(output_batch[0][0] - output[0][0])
|
||||
diff = diff.flatten()
|
||||
diff.sort()
|
||||
max_diff = np.median(diff[-5:])
|
||||
else:
|
||||
max_diff = np.abs(output_batch[0][0] - output[0][0]).max()
|
||||
assert max_diff < expected_max_diff
|
||||
|
||||
if test_mean_pixel_difference:
|
||||
assert_mean_pixel_difference(output_batch[0][0], output[0][0])
|
||||
|
||||
# We do not support saving pipelines with multiple adapters. The multiple adapters should be saved as their
|
||||
# own independent pipelines
|
||||
|
||||
def test_save_load_local(self):
|
||||
...
|
||||
|
||||
def test_save_load_optional_components(self):
|
||||
...
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class StableDiffusionAdapterPipelineSlowTests(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user