diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index 3a1be76708..90deb3fa88 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -1,6 +1,7 @@ name: Fast GPU Tests on main on: + workflow_dispatch: push: branches: - main diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md index dd3c75ee12..e006006a33 100644 --- a/docs/source/en/api/pipelines/flux.md +++ b/docs/source/en/api/pipelines/flux.md @@ -163,3 +163,15 @@ image.save("flux-fp8-dev.png") [[autodoc]] FluxPipeline - all - __call__ + +## FluxImg2ImgPipeline + +[[autodoc]] FluxImg2ImgPipeline + - all + - __call__ + +## FluxInpaintPipeline + +[[autodoc]] FluxInpaintPipeline + - all + - __call__ diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index 952d86a1f2..eaa0ebd806 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -8,8 +8,10 @@ The `train_dreambooth_flux.py` script shows how to implement the training proced > > Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements - > a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training. -> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md) +> For more tips & guidance on training on a resource-constrained device and general good practices, please check out these great guides and trainers for FLUX: +> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md) +> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux1-training) > [!NOTE] > **Gated model** @@ -100,8 +102,10 @@ accelerate launch train_dreambooth_flux.py \ --instance_prompt="a photo of sks dog" \ --resolution=1024 \ --train_batch_size=1 \ + --guidance_scale=1 \ --gradient_accumulation_steps=4 \ - --learning_rate=1e-4 \ + --optimizer="prodigy" \ + --learning_rate=1. \ --report_to="wandb" \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ @@ -120,15 +124,23 @@ To better track our training experiments, we're using the following flags in the > [!NOTE] > If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases. -> [!TIP] -> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so. - ## LoRA + DreamBooth [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters. Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. +### Prodigy Optimizer +Prodigy is an adaptive optimizer that dynamically adjusts the learning rate of learned parameters based on past gradients, allowing for more efficient convergence. +By using Prodigy we can "eliminate" the need for manual learning rate tuning. Read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).
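For intuition, here is a minimal sketch of how a Prodigy optimizer is typically constructed via the `prodigyopt` package (which the script expects to be installed); the `nn.Linear` model is a toy stand-in for the trainable Flux parameters, not the script's actual code:

```python
import torch.nn as nn
from prodigyopt import Prodigy  # pip install prodigyopt

# Toy stand-in; the training script passes the trainable (e.g. LoRA)
# parameters of the Flux transformer here instead.
model = nn.Linear(16, 16)

# Prodigy estimates the step size from past gradients, which is why the
# learning rate is conventionally left at 1.0 rather than hand-tuned.
optimizer = Prodigy(model.parameters(), lr=1.0)
```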
+ +To use Prodigy, specify: +```bash +--optimizer="prodigy" +``` +> [!TIP] +> When using Prodigy, it's generally good practice to set `--learning_rate=1.0`. + To perform DreamBooth with LoRA, run: ```bash @@ -144,8 +156,10 @@ accelerate launch train_dreambooth_lora_flux.py \ --instance_prompt="a photo of sks dog" \ --resolution=512 \ --train_batch_size=1 \ + --guidance_scale=1 \ --gradient_accumulation_steps=4 \ - --learning_rate=1e-5 \ + --optimizer="prodigy" \ + --learning_rate=1. \ --report_to="wandb" \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ @@ -162,6 +176,7 @@ Alongside the transformer, fine-tuning of the CLIP text encoder is also supporte To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind: > [!NOTE] +> This is still an experimental feature. > FLUX.1 has 2 text encoders (CLIP L/14 and T5-v1.1-XXL). By enabling `--train_text_encoder`, fine-tuning of the **CLIP encoder** is performed. > At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled. @@ -180,8 +195,10 @@ accelerate launch train_dreambooth_lora_flux.py \ --instance_prompt="a photo of sks dog" \ --resolution=512 \ --train_batch_size=1 \ + --guidance_scale=1 \ --gradient_accumulation_steps=4 \ - --learning_rate=1e-5 \ + --optimizer="prodigy" \ + --learning_rate=1. \ --report_to="wandb" \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ @@ -191,5 +208,21 @@ accelerate launch train_dreambooth_lora_flux.py \ --push_to_hub ``` +## Memory Optimizations +As mentioned, Flux DreamBooth LoRA training is very memory intensive. Here are some options (some still experimental) for more memory-efficient training. +### Image Resolution +An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution of the input images; all images in the train/validation dataset are resized to it. +Note that by default, images are resized to a resolution of 512, but it's good to keep in mind in case you're accustomed to training on higher resolutions. +### Gradient Checkpointing and Accumulation +* `--gradient_accumulation_steps` refers to the number of update steps to accumulate before performing a backward/update pass. +By passing a value > 1 you can reduce the number of backward/update passes and hence also the memory requirements. +* With `--gradient_checkpointing` we can save memory by not storing all intermediate activations during the forward pass. +Instead, only a subset of these activations (the checkpoints) is stored and the rest are recomputed as needed during the backward pass. Note that this comes at the expense of a slower backward pass. +### 8-bit Adam Optimizer +When training with `AdamW` (this doesn't apply to `prodigy`), you can pass `--use_8bit_adam` to reduce the memory requirements of training. +Make sure to install `bitsandbytes` if you want to do so. +### Latent Caching +When training without validation runs, we can pre-encode the training images with the VAE, and then delete it to free up some memory.
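Conceptually, the caching step amounts to something like the following rough sketch (names are illustrative, not the script's exact code; `vae` is assumed to be the pipeline's `AutoencoderKL` and `train_dataloader` to yield preprocessed pixel tensors):

```python
import torch

# Encode every training image once up front and keep only the latents.
latents_cache = []
with torch.no_grad():
    for batch in train_dataloader:
        pixel_values = batch["pixel_values"].to(vae.device, dtype=vae.dtype)
        latents_cache.append(vae.encode(pixel_values).latent_dist)

# Without validation runs there is nothing left to encode or decode,
# so the VAE can be dropped entirely to free its memory.
del vae
torch.cuda.empty_cache()
```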
+To enable latent caching, first use the version in [this PR](https://github.com/huggingface/diffusers/blob/1b195933d04e4c8281a2634128c0d2d380893f73/examples/dreambooth/train_dreambooth_lora_flux.py), and then pass `--cache_latents`. ## Other notes -Thanks to `bghira` for their help with reviewing & insight sharing ♥️ \ No newline at end of file +Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️ \ No newline at end of file diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 2e77cb946f..17e6e107b0 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -15,7 +15,6 @@ import argparse import copy -import gc import itertools import logging import math @@ -56,6 +55,7 @@ from diffusers.optimization import get_scheduler from diffusers.training_utils import ( _set_state_dict_into_text_encoder, cast_training_params, + clear_objs_and_retain_memory, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, ) @@ -210,9 +210,7 @@ def log_validation( } ) - del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clear_objs_and_retain_memory(objs=[pipeline]) return images @@ -1107,9 +1105,7 @@ def main(args): image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) - del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clear_objs_and_retain_memory(objs=[pipeline]) # Handle the repository creation if accelerator.is_main_process: @@ -1455,12 +1451,10 @@ def main(args): # Clear the memory here if not args.train_text_encoder and not train_dataset.custom_instance_prompts: - del tokenizers, text_encoders # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection - del text_encoder_one, text_encoder_two, text_encoder_three - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clear_objs_and_retain_memory( + objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three] + ) # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), # pack the statically computed variables appropriately here.
This is so that we don't @@ -1795,11 +1789,11 @@ def main(args): pipeline_args=pipeline_args, epoch=epoch, ) + objs = [] if not args.train_text_encoder: - del text_encoder_one, text_encoder_two, text_encoder_three + objs.extend([text_encoder_one, text_encoder_two, text_encoder_three]) - torch.cuda.empty_cache() - gc.collect() + clear_objs_and_retain_memory(objs=objs) # Save the lora layers accelerator.wait_for_everyone() diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index bb8ceccb76..af28b383b5 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -258,6 +258,8 @@ else: "CogVideoXVideoToVideoPipeline", "CycleDiffusionPipeline", "FluxControlNetPipeline", + "FluxImg2ImgPipeline", + "FluxInpaintPipeline", "FluxPipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", @@ -703,6 +705,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: CogVideoXVideoToVideoPipeline, CycleDiffusionPipeline, FluxControlNetPipeline, + FluxImg2ImgPipeline, + FluxInpaintPipeline, FluxPipeline, HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 8738ff49fa..d58bd9e3e3 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -569,7 +569,7 @@ class VaeImageProcessor(ConfigMixin): channel = image.shape[1] # don't need any preprocess if the image is latents - if channel == self.vae_latent_channels: + if channel == self.config.vae_latent_channels: return image height, width = self.get_default_height_width(image, height, width) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 4b54269479..f6dea33e8e 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -562,7 +562,8 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict): new_key += ".attn.to_out.0" elif "processor.proj_lora2" in old_key: new_key += ".attn.to_add_out" - elif "processor.qkv_lora1" in old_key and "up" not in old_key: + # Handle text latents. + elif "processor.qkv_lora2" in old_key and "up" not in old_key: handle_qkv( old_state_dict, new_state_dict, @@ -574,7 +575,8 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict): ], ) # continue - elif "processor.qkv_lora2" in old_key and "up" not in old_key: + # Handle image latents. + elif "processor.qkv_lora1" in old_key and "up" not in old_key: handle_qkv( old_state_dict, new_state_dict, diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 7766442f71..84db0d0617 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -1104,8 +1104,26 @@ class FreeNoiseTransformerBlock(nn.Module): accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights num_times_accumulated[:, frame_start:frame_end] += weights - hidden_states = torch.where( - num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values + # TODO(aryan): Maybe this could be done in a better way. + # + # Previously, this was: + # hidden_states = torch.where( + # num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values + # ) + # + # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory + # spikes. It is particularly noticeable when the number of frames is high. 
My understanding is that this comes + # from tensors being copied - which is why we resort to spliting and concatenating here. I've not particularly + # looked into this deeply because other memory optimizations led to more pronounced reductions. + hidden_states = torch.cat( + [ + torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split) + for accumulated_split, num_times_split in zip( + accumulated_values.split(self.context_length, dim=1), + num_times_accumulated.split(self.context_length, dim=1), + ) + ], + dim=1, ).to(dtype) # 3. Feed-forward diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 5e9863ab0d..eb5067c377 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -342,15 +342,58 @@ class CogVideoXPatchEmbed(nn.Module): embed_dim: int = 1920, text_embed_dim: int = 4096, bias: bool = True, + sample_width: int = 90, + sample_height: int = 60, + sample_frames: int = 49, + temporal_compression_ratio: int = 4, + max_text_seq_length: int = 226, + spatial_interpolation_scale: float = 1.875, + temporal_interpolation_scale: float = 1.0, + use_positional_embeddings: bool = True, ) -> None: super().__init__() + self.patch_size = patch_size + self.embed_dim = embed_dim + self.sample_height = sample_height + self.sample_width = sample_width + self.sample_frames = sample_frames + self.temporal_compression_ratio = temporal_compression_ratio + self.max_text_seq_length = max_text_seq_length + self.spatial_interpolation_scale = spatial_interpolation_scale + self.temporal_interpolation_scale = temporal_interpolation_scale + self.use_positional_embeddings = use_positional_embeddings self.proj = nn.Conv2d( in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias ) self.text_proj = nn.Linear(text_embed_dim, embed_dim) + if use_positional_embeddings: + pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames) + self.register_buffer("pos_embedding", pos_embedding, persistent=False) + + def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor: + post_patch_height = sample_height // self.patch_size + post_patch_width = sample_width // self.patch_size + post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1 + num_patches = post_patch_height * post_patch_width * post_time_compression_frames + + pos_embedding = get_3d_sincos_pos_embed( + self.embed_dim, + (post_patch_width, post_patch_height), + post_time_compression_frames, + self.spatial_interpolation_scale, + self.temporal_interpolation_scale, + ) + pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1) + joint_pos_embedding = torch.zeros( + 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False + ) + joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding) + + return joint_pos_embedding + def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): r""" Args: @@ -371,6 +414,21 @@ class CogVideoXPatchEmbed(nn.Module): embeds = torch.cat( [text_embeds, image_embeds], dim=1 ).contiguous() # [batch, seq_length + num_frames x height x width, channels] + + if self.use_positional_embeddings: + pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 + if ( + self.sample_height != height + or self.sample_width != width + or self.sample_frames != pre_time_compression_frames + ): + pos_embedding = 
self._get_positional_embeddings(height, width, pre_time_compression_frames) + pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype) + else: + pos_embedding = self.pos_embedding + + embeds = embeds + pos_embedding + return embeds @@ -550,8 +608,11 @@ def get_1d_rotary_pos_embed( pos = torch.from_numpy(pos) # type: ignore # [S] theta = theta * ntk_factor - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2] - freqs = freqs.to(pos.device) + freqs = ( + 1.0 + / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) + / linear_factor + ) # [D/2] freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] if use_real and repeat_interleave_real: # flux, hunyuan-dit, cogvideox diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 753514a42e..12435fa340 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -24,7 +24,7 @@ from ...utils import is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 -from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed +from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero @@ -240,33 +240,29 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): super().__init__() inner_dim = num_attention_heads * attention_head_dim - post_patch_height = sample_height // patch_size - post_patch_width = sample_width // patch_size - post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1 - self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames - # 1. Patch embedding - self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True) + self.patch_embed = CogVideoXPatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + text_embed_dim=text_embed_dim, + bias=True, + sample_width=sample_width, + sample_height=sample_height, + sample_frames=sample_frames, + temporal_compression_ratio=temporal_compression_ratio, + max_text_seq_length=max_text_seq_length, + spatial_interpolation_scale=spatial_interpolation_scale, + temporal_interpolation_scale=temporal_interpolation_scale, + use_positional_embeddings=not use_rotary_positional_embeddings, + ) self.embedding_dropout = nn.Dropout(dropout) - # 2. 3D positional embeddings - spatial_pos_embedding = get_3d_sincos_pos_embed( - inner_dim, - (post_patch_width, post_patch_height), - post_time_compression_frames, - spatial_interpolation_scale, - temporal_interpolation_scale, - ) - spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) - pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False) - pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding) - self.register_buffer("pos_embedding", pos_embedding, persistent=False) - - # 3. Time embeddings + # 2. 
Time embeddings self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) - # 4. Define spatio-temporal transformers blocks + # 3. Define spatio-temporal transformers blocks self.transformer_blocks = nn.ModuleList( [ CogVideoXBlock( @@ -285,7 +281,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ) self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) - # 5. Output blocks + # 4. Output blocks self.norm_out = AdaLayerNorm( embedding_dim=time_embed_dim, output_dim=2 * inner_dim, @@ -423,20 +419,13 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): # 2. Patch embedding hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + hidden_states = self.embedding_dropout(hidden_states) - # 3. Position embedding text_seq_length = encoder_hidden_states.shape[1] - if not self.config.use_rotary_positional_embeddings: - seq_length = height * width * num_frames // (self.config.patch_size**2) - - pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length] - hidden_states = hidden_states + pos_embeds - hidden_states = self.embedding_dropout(hidden_states) - encoder_hidden_states = hidden_states[:, :text_seq_length] hidden_states = hidden_states[:, text_seq_length:] - # 4. Transformer blocks + # 3. Transformer blocks for i, block in enumerate(self.transformer_blocks): if self.training and self.gradient_checkpointing: @@ -472,11 +461,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): hidden_states = self.norm_final(hidden_states) hidden_states = hidden_states[:, text_seq_length:] - # 5. Final block + # 4. Final block hidden_states = self.norm_out(hidden_states, temb=emb) hidden_states = self.proj_out(hidden_states) - # 6. Unpatchify + # 5. Unpatchify p = self.config.patch_size output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p) output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 9a168bd22c..4f55df32b7 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -463,7 +463,6 @@ class UNet2DConditionModel( dropout=dropout, ) self.up_blocks.append(up_block) - prev_output_channel = output_channel # out if norm_num_groups is not None: @@ -599,7 +598,7 @@ class UNet2DConditionModel( ) elif encoder_hid_dim_type is not None: raise ValueError( - f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'." ) else: self.encoder_hid_proj = None @@ -679,7 +678,9 @@ class UNet2DConditionModel( # Kandinsky 2.2 ControlNet self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type is not None: - raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + raise ValueError( + f"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'." 
+ ) def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int): if attention_type in ["gated", "gated-text-image"]: @@ -990,7 +991,7 @@ class UNet2DConditionModel( image_embs = added_cond_kwargs.get("image_embeds") aug_emb = self.add_embedding(image_embs) elif self.config.addition_embed_type == "image_hint": - # Kandinsky 2.2 - style + # Kandinsky 2.2 ControlNet - style if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" @@ -1009,7 +1010,7 @@ class UNet2DConditionModel( # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embeds = added_cond_kwargs.get("image_embeds") @@ -1018,14 +1019,14 @@ class UNet2DConditionModel( # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": if "image_embeds" not in added_cond_kwargs: raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None: @@ -1140,7 +1141,6 @@ class UNet2DConditionModel( # 1. time t_emb = self.get_time_embed(sample=sample, timestep=timestep) emb = self.time_embedding(t_emb, timestep_cond) - aug_emb = None class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) if class_emb is not None: diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 89cdb76741..6125feba58 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -187,12 +187,12 @@ class AnimateDiffTransformer3D(nn.Module): hidden_states = self.norm(hidden_states) hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) - hidden_states = self.proj_in(hidden_states) + hidden_states = self.proj_in(input=hidden_states) # 2. 
Blocks for block in self.transformer_blocks: hidden_states = block( - hidden_states, + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, @@ -200,7 +200,7 @@ class AnimateDiffTransformer3D(nn.Module): ) # 3. Output - hidden_states = self.proj_out(hidden_states) + hidden_states = self.proj_out(input=hidden_states) hidden_states = ( hidden_states[None, None, :] .reshape(batch_size, height, width, num_frames, channel) @@ -344,7 +344,7 @@ class DownBlockMotion(nn.Module): ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) hidden_states = motion_module(hidden_states, num_frames=num_frames) @@ -352,7 +352,7 @@ class DownBlockMotion(nn.Module): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states=hidden_states) output_states = output_states + (hidden_states,) @@ -531,25 +531,18 @@ class CrossAttnDownBlockMotion(nn.Module): temb, **ckpt_kwargs, ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) + + hidden_states = attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] hidden_states = motion_module( hidden_states, num_frames=num_frames, @@ -563,7 +556,7 @@ class CrossAttnDownBlockMotion(nn.Module): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states=hidden_states) output_states = output_states + (hidden_states,) @@ -757,25 +750,18 @@ class CrossAttnUpBlockMotion(nn.Module): temb, **ckpt_kwargs, ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) + + hidden_states = attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] hidden_states = motion_module( hidden_states, num_frames=num_frames, @@ -783,7 +769,7 @@ class CrossAttnUpBlockMotion(nn.Module): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, 
upsample_size) + hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size) return hidden_states @@ -929,13 +915,13 @@ class UpBlockMotion(nn.Module): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) hidden_states = motion_module(hidden_states, num_frames=num_frames) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size) return hidden_states @@ -1080,10 +1066,19 @@ class UNetMidBlockCrossAttnMotion(nn.Module): if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb) blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) for attn, resnet, motion_module in blocks: + hidden_states = attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): @@ -1096,14 +1091,6 @@ class UNetMidBlockCrossAttnMotion(nn.Module): return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(motion_module), hidden_states, @@ -1117,19 +1104,11 @@ class UNetMidBlockCrossAttnMotion(nn.Module): **ckpt_kwargs, ) else: - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] hidden_states = motion_module( hidden_states, num_frames=num_frames, ) - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) return hidden_states diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index a999e0441d..ad7ea2872a 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -124,7 +124,12 @@ else: "AnimateDiffSparseControlNetPipeline", "AnimateDiffVideoToVideoPipeline", ] - _import_structure["flux"] = ["FluxPipeline", "FluxControlNetPipeline"] + _import_structure["flux"] = [ + "FluxControlNetPipeline", + "FluxImg2ImgPipeline", + "FluxInpaintPipeline", + "FluxPipeline", + ] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", @@ -494,7 +499,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: VersatileDiffusionTextToImagePipeline, VQDiffusionPipeline, ) - from .flux import FluxControlNetPipeline, FluxPipeline + from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline from 
.kandinsky import ( diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 4daf0e7717..39ceadb5ac 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -29,7 +29,7 @@ from .controlnet import ( StableDiffusionXLControlNetPipeline, ) from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline -from .flux import FluxPipeline +from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline from .hunyuandit import HunyuanDiTPipeline from .kandinsky import ( KandinskyCombinedPipeline, @@ -108,6 +108,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( ("pixart-sigma-pag", PixArtSigmaPAGPipeline), ("auraflow", AuraFlowPipeline), ("flux", FluxPipeline), + ("flux-controlnet", FluxControlNetPipeline), ("lumina", LuminaText2ImgPipeline), ] ) @@ -126,6 +127,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict( ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline), ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline), ("lcm", LatentConsistencyModelImg2ImgPipeline), + ("flux", FluxImg2ImgPipeline), ] ) @@ -140,6 +142,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict( ("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline), + ("flux", FluxInpaintPipeline), ] ) @@ -660,12 +663,17 @@ class AutoPipelineForImage2Image(ConfigMixin): config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) orig_class_name = config["_class_name"] + # the `orig_class_name` can be: + # `- *Pipeline` (for regular text-to-image checkpoint) + # `- *Img2ImgPipeline` (for refiner checkpoint) + to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline" + if "controlnet" in kwargs: - orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") + orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) if "enable_pag" in kwargs: enable_pag = kwargs.pop("enable_pag") if enable_pag: - orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline") + orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace) image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name) @@ -952,14 +960,17 @@ class AutoPipelineForInpainting(ConfigMixin): config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) orig_class_name = config["_class_name"] + # The `orig_class_name`` can be: + # `- *InpaintPipeline` (for inpaint-specific checkpoint) + # - or *Pipeline (for regular text-to-image checkpoint) + to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline" + if "controlnet" in kwargs: - orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") + orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) if "enable_pag" in kwargs: enable_pag = kwargs.pop("enable_pag") if enable_pag: - to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline" - orig_class_name = config["_class_name"].replace(to_replace, "PAG" + to_replace) - + orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace) inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name) kwargs = {**load_config_kwargs, **kwargs} diff --git 
a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 23dac5abd0..3937e87f63 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -546,7 +546,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ) elif encoder_hid_dim_type is not None: raise ValueError( - f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'." ) else: self.encoder_hid_proj = None diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py index 900189102c..e43a7ab753 100644 --- a/src/diffusers/pipelines/flux/__init__.py +++ b/src/diffusers/pipelines/flux/__init__.py @@ -24,6 +24,8 @@ except OptionalDependencyNotAvailable: else: _import_structure["pipeline_flux"] = ["FluxPipeline"] _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"] + _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"] + _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): @@ -33,6 +35,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: else: from .pipeline_flux import FluxPipeline from .pipeline_flux_controlnet import FluxControlNetPipeline + from .pipeline_flux_img2img import FluxImg2ImgPipeline + from .pipeline_flux_inpaint import FluxInpaintPipeline else: import sys diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py new file mode 100644 index 0000000000..bee4f6ce52 --- /dev/null +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -0,0 +1,844 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + + >>> from diffusers import FluxImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> device = "cuda" + >>> pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe = pipe.to(device) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> init_image = load_image(url).resize((1024, 1024)) + + >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k" + + >>> images = pipe( + ... prompt=prompt, image=init_image, num_inference_steps=4, strength=0.95, guidance_scale=0.0 + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. 
+ device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): + r""" + The Flux pipeline for image inpainting. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). 
+ """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 64 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if 
untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if 
callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + + return latents + + def prepare_latents( + self, + image, + timestep, + batch_size, + 
num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_image_ids + + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.6, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. 
If not defined, `prompt` + will be used instead. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list + of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + strength (`float`, *optional*, defaults to 0.6): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 28): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Preprocess image + init_image = self.image_processor.preprocess(image, height=height, width=width) + init_image = init_image.to(dtype=torch.float32) + + # 3. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline " + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + + latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (e.g. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py new file mode 100644 index 0000000000..4603367002 --- /dev/null +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -0,0 +1,1009 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
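Both new pipelines turn `strength` into a truncated schedule via the `get_timesteps` helper above: the first `num_inference_steps * (1 - strength)` steps are skipped, so denoising starts from a partially noised version of `image`. A minimal standalone sketch of that arithmetic (an illustration, assuming a scheduler `order` of 1 as with `FlowMatchEulerDiscreteScheduler`):

```python
# Sketch of the strength-to-timesteps arithmetic used by `get_timesteps` (order == 1).
def effective_steps(num_inference_steps: int, strength: float) -> tuple[int, int]:
    init_timestep = min(num_inference_steps * strength, num_inference_steps)
    t_start = int(max(num_inference_steps - init_timestep, 0))
    # the pipeline then denoises over timesteps[t_start:]
    return t_start, num_inference_steps - t_start


print(effective_steps(28, 0.6))  # (11, 17): roughly 60% of the 28 steps actually run
print(effective_steps(28, 1.0))  # (0, 28): full schedule, the input image is essentially ignored
```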
+ +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = FluxInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + >>> source = load_image(img_url) + >>> mask = load_image(mask_url) + >>> image = pipe(prompt=prompt, image=source, mask_image=mask).images[0] + >>> image.save("flux_inpainting.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. 
+ num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): + r""" + The Flux pipeline for image inpainting. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). 
+ tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, + vae_latent_channels=self.vae.config.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 64 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = 
len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
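Example (an illustrative sketch, not part of the original diff; the checkpoint and prompt values are assumptions):

```py
>>> import torch
>>> from diffusers import FluxInpaintPipeline

>>> pipe = FluxInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
>>> prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
...     prompt="Face of a yellow cat, high resolution",
...     prompt_2=None,  # falls back to `prompt` for the T5 encoder
...     max_sequence_length=512,
... )
```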
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + image, + mask_image, + strength, + height, + width, + output_type, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible 
by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." 
+ ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + + return latents + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
+ ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + else: + noise = latents.to(device) + latents = noise + + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, noise, image_latents, latent_image_ids + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 16: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concatenating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + mask = self._pack_latents( + mask.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + + return mask, masked_image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + padding_mask_crop: Optional[int] = None, + strength: float = 0.6, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` + will be used instead. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list + of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. 
If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for a numpy array, it would be `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + masked_image_latents (`torch.Tensor`, `List[torch.Tensor]`): + `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask + latents tensor will be generated from `mask_image`. + height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ratio as the image that contains all masked areas, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contains information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 0.6): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 28): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. 
Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + mask_image, + strength, + height, + width, + output_type=output_type, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + padding_mask_crop=padding_mask_crop, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 3. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline " + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + num_channels_transformer = self.transformer.config.in_channels + + latents, noise, image_latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # for 64 channel transformer only. + init_latents_proper = image_latents + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (e.g. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index f2763f1c33..dc0071a494 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Optional, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch +import torch.nn as nn from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock +from ..models.resnet import Downsample2D, ResnetBlock2D, Upsample2D +from ..models.transformers.transformer_2d import Transformer2DModel from ..models.unets.unet_motion_model import ( + AnimateDiffTransformer3D, CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion, @@ -30,6 +34,114 @@ from ..utils.torch_utils import randn_tensor logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class SplitInferenceModule(nn.Module): + r""" + A wrapper module class that splits inputs along a specified dimension before performing a forward pass. 
+ + This module is useful when you need to perform inference on large tensors in a memory-efficient way by breaking + them into smaller chunks, processing each chunk separately, and then reassembling the results. + + Args: + module (`nn.Module`): + The underlying PyTorch module that will be applied to each chunk of split inputs. + split_size (`int`, defaults to `1`): + The size of each chunk after splitting the input tensor. + split_dim (`int`, defaults to `0`): + The dimension along which the input tensors are split. + input_kwargs_to_split (`List[str]`, defaults to `["hidden_states"]`): + A list of keyword arguments (strings) that represent the input tensors to be split. + + Workflow: + 1. The keyword arguments specified in `input_kwargs_to_split` are split into smaller chunks using + `torch.split()` along the dimension `split_dim` and with a chunk size of `split_size`. + 2. The `module` is invoked once for each split with both the split inputs and any unchanged arguments + that were passed. + 3. The output tensors from each split are concatenated back together along `split_dim` before returning. + + Example: + ```python + >>> import torch + >>> import torch.nn as nn + + >>> model = nn.Linear(1000, 1000) + >>> split_module = SplitInferenceModule(model, split_size=2, split_dim=0, input_kwargs_to_split=["input"]) + + >>> input_tensor = torch.randn(42, 1000) + >>> # Will split the tensor into 21 slices of shape [2, 1000]. + >>> output = split_module(input=input_tensor) + ``` + + It is also possible to nest `SplitInferenceModule` across different split dimensions for more complex + multi-dimensional splitting. + """ + + def __init__( + self, + module: nn.Module, + split_size: int = 1, + split_dim: int = 0, + input_kwargs_to_split: List[str] = ["hidden_states"], + ) -> None: + super().__init__() + + self.module = module + self.split_size = split_size + self.split_dim = split_dim + self.input_kwargs_to_split = set(input_kwargs_to_split) + + def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + r"""Forward method for the `SplitInferenceModule`. + + This method processes the input by splitting specified keyword arguments along a given dimension, running the + underlying module on each split, and then concatenating the results. The splitting is controlled by the + `split_size` and `split_dim` parameters specified during initialization. + + Args: + *args (`Any`): + Positional arguments that are passed directly to the `module` without modification. + **kwargs (`Dict[str, torch.Tensor]`): + Keyword arguments passed to the underlying `module`. Only keyword arguments whose names match the + entries in `input_kwargs_to_split` and are of type `torch.Tensor` will be split. The remaining keyword + arguments are passed unchanged. + + Returns: + `Union[torch.Tensor, Tuple[torch.Tensor]]`: + The outputs obtained from `SplitInferenceModule` are the same as if the underlying module was inferred + without it. + - If the underlying module returns a single tensor, the result will be a single concatenated tensor + along the same `split_dim` after processing all splits. + - If the underlying module returns a tuple of tensors, each element of the tuple will be concatenated + along the `split_dim` across all splits, and the final result will be a tuple of concatenated tensors. + """ + split_inputs = {} + + # 1. 
Split inputs that were specified during initialization and also present in passed kwargs + for key in list(kwargs.keys()): + if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]): + continue + split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim) + kwargs.pop(key) + + # 2. Invoke forward pass across each split + results = [] + for split_input in zip(*split_inputs.values()): + inputs = dict(zip(split_inputs.keys(), split_input)) + inputs.update(kwargs) + + intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs) + results.append(intermediate_tensor_or_tensor_tuple) + + # 3. Concatenate split results to obtain final outputs + if isinstance(results[0], torch.Tensor): + return torch.cat(results, dim=self.split_dim) + elif isinstance(results[0], tuple): + return tuple([torch.cat(x, dim=self.split_dim) for x in zip(*results)]) + else: + raise ValueError( + "In order to use the SplitInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensors." + ) + + class AnimateDiffFreeNoiseMixin: r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169).""" @@ -70,6 +182,9 @@ class AnimateDiffFreeNoiseMixin: motion_module.transformer_blocks[i].load_state_dict( basic_transfomer_block.state_dict(), strict=True ) + motion_module.transformer_blocks[i].set_chunk_feed_forward( + basic_transfomer_block._chunk_size, basic_transfomer_block._chunk_dim + ) def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]): r"""Helper function to disable FreeNoise in transformer blocks.""" @@ -98,6 +213,9 @@ class AnimateDiffFreeNoiseMixin: motion_module.transformer_blocks[i].load_state_dict( free_noise_transfomer_block.state_dict(), strict=True ) + motion_module.transformer_blocks[i].set_chunk_feed_forward( + free_noise_transfomer_block._chunk_size, free_noise_transfomer_block._chunk_dim + ) def _check_inputs_free_noise( self, @@ -410,6 +528,69 @@ class AnimateDiffFreeNoiseMixin: for block in blocks: self._disable_free_noise_in_block(block) + def _enable_split_inference_motion_modules_( + self, motion_modules: List[AnimateDiffTransformer3D], spatial_split_size: int + ) -> None: + for motion_module in motion_modules: + motion_module.proj_in = SplitInferenceModule(motion_module.proj_in, spatial_split_size, 0, ["input"]) + + for i in range(len(motion_module.transformer_blocks)): + motion_module.transformer_blocks[i] = SplitInferenceModule( + motion_module.transformer_blocks[i], + spatial_split_size, + 0, + ["hidden_states", "encoder_hidden_states"], + ) + + motion_module.proj_out = SplitInferenceModule(motion_module.proj_out, spatial_split_size, 0, ["input"]) + + def _enable_split_inference_attentions_( + self, attentions: List[Transformer2DModel], temporal_split_size: int + ) -> None: + for i in range(len(attentions)): + attentions[i] = SplitInferenceModule( + attentions[i], temporal_split_size, 0, ["hidden_states", "encoder_hidden_states"] + ) + + def _enable_split_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_split_size: int) -> None: + for i in range(len(resnets)): + resnets[i] = SplitInferenceModule(resnets[i], temporal_split_size, 0, ["input_tensor", "temb"]) + + def _enable_split_inference_samplers_( + self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_split_size: int + ) -> None: + for i in range(len(samplers)): + samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"]) + + def
enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None: + r""" + Enable FreeNoise memory optimizations by utilizing + [`~diffusers.pipelines.free_noise_utils.SplitInferenceModule`] across different intermediate modeling blocks. + + Args: + spatial_split_size (`int`, defaults to `256`): + The split size across spatial dimensions for internal blocks. This is used in facilitating split + inference across the effective batch dimension (`[B x H x W, F, C]`) of intermediate tensors in motion + modeling blocks. + temporal_split_size (`int`, defaults to `16`): + The split size across temporal dimensions for internal blocks. This is used in facilitating split + inference across the effective batch dimension (`[B x F, H x W, C]`) of intermediate tensors in spatial + attention, resnets, downsampling and upsampling blocks. + """ + # TODO(aryan): Discuss the best way to provide more control to users + blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] + for block in blocks: + if getattr(block, "motion_modules", None) is not None: + self._enable_split_inference_motion_modules_(block.motion_modules, spatial_split_size) + if getattr(block, "attentions", None) is not None: + self._enable_split_inference_attentions_(block.attentions, temporal_split_size) + if getattr(block, "resnets", None) is not None: + self._enable_split_inference_resnets_(block.resnets, temporal_split_size) + if getattr(block, "downsamplers", None) is not None: + self._enable_split_inference_samplers_(block.downsamplers, temporal_split_size) + if getattr(block, "upsamplers", None) is not None: + self._enable_split_inference_samplers_(block.upsamplers, temporal_split_size) + @property def free_noise_enabled(self): return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py index d5dedae165..440b6529c9 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py @@ -25,7 +25,7 @@ from transformers import ( from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import SD3LoraLoaderMixin +from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import SD3Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -148,7 +148,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class StableDiffusion3InpaintPipeline(DiffusionPipeline): +class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin): r""" Args: transformer ([`SD3Transformer2DModel`]): diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index f497fcc613..26d4a2a504 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -1,5 +1,6 @@ import contextlib import copy +import gc import math import random from typing import Any, Dict, Iterable, List, Optional, Tuple, Union @@ -259,6 +260,22 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): return weighting +def clear_objs_and_retain_memory(objs: List[Any]): + """Deletes `objs` and runs
garbage collection. Then clears the cache of the available accelerator.""" + if len(objs) >= 1: + for obj in objs: + del obj + + gc.collect() + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif torch.backends.mps.is_available(): + torch.mps.empty_cache() + elif is_torch_npu_available(): + torch_npu.empty_cache() + + # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 class EMAModel: """ diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 644a148a8b..ff1f38d731 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -317,6 +317,36 @@ class FluxControlNetPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class FluxImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class FluxInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class FluxPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/lora/test_lora_layers_sd.py b/tests/lora/test_lora_layers_sd.py index 0aee4f57c2..0f606a056f 100644 --- a/tests/lora/test_lora_layers_sd.py +++ b/tests/lora/test_lora_layers_sd.py @@ -157,11 +157,12 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)): self.assertTrue(m.weight.device != torch.device("cpu")) + @slow @require_torch_gpu def test_integration_move_lora_dora_cpu(self): from peft import LoraConfig - path = "runwayml/stable-diffusion-v1-5" + path = "Lykon/dreamshaper-8" unet_lora_config = LoraConfig( init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"], diff --git a/tests/models/autoencoders/test_models_vae.py b/tests/models/autoencoders/test_models_vae.py index 38cbd788a9..1b1c9b3521 100644 --- a/tests/models/autoencoders/test_models_vae.py +++ b/tests/models/autoencoders/test_models_vae.py @@ -528,6 +528,10 @@ class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCa def test_forward_with_norm_groups(self): pass + @unittest.skip("No attention module used in this model") + def test_set_attn_processor_for_determinism(self): + return + @slow class AutoencoderTinyIntegrationTests(unittest.TestCase): @@ -1032,9 +1036,9 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/img2img/sketch-mountains-input.jpg" ).resize((256, 256)) - image = torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[ - None, :, :, : - ].cuda() + image = torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, 
:].to( + torch_device + ) latent = vae.encode(image).latent_dist.mean @@ -1075,7 +1079,7 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase): image = ( torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, :] .half() - .cuda() + .to(torch_device) ) latent = vae.encode(image).latent_dist.mean diff --git a/tests/models/test_layers_utils.py b/tests/models/test_layers_utils.py index b5a5bec471..66e142f8c6 100644 --- a/tests/models/test_layers_utils.py +++ b/tests/models/test_layers_utils.py @@ -55,17 +55,6 @@ class EmbeddingsTests(unittest.TestCase): assert grad > prev_grad prev_grad = grad - def test_timestep_defaults(self): - embedding_dim = 16 - timesteps = torch.arange(10) - - t1 = get_timestep_embedding(timesteps, embedding_dim) - t2 = get_timestep_embedding( - timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10_000 - ) - - assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3) - def test_timestep_flip_sin_cos(self): embedding_dim = 16 timesteps = torch.arange(10) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 2437a5a55c..b56ac233ef 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -183,17 +183,6 @@ class ModelUtilsTest(unittest.TestCase): class UNetTesterMixin: - def test_forward_signature(self): - init_dict, _ = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - signature = inspect.signature(model.forward) - # signature.parameters is an OrderedDict => so arg_names order is deterministic - arg_names = [*signature.parameters.keys()] - - expected_arg_names = ["sample", "timestep"] - self.assertListEqual(arg_names[:2], expected_arg_names) - def test_forward_with_norm_groups(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -220,6 +209,7 @@ class ModelTesterMixin: base_precision = 1e-3 forward_requires_fresh_args = False model_split_percents = [0.5, 0.7, 0.9] + uses_custom_attn_processor = False def check_device_map_is_respected(self, model, device_map): for param_name, param in model.named_parameters(): diff --git a/tests/models/transformers/test_models_transformer_cogvideox.py b/tests/models/transformers/test_models_transformer_cogvideox.py index 83cdf87baa..6db4113cbd 100644 --- a/tests/models/transformers/test_models_transformer_cogvideox.py +++ b/tests/models/transformers/test_models_transformer_cogvideox.py @@ -32,6 +32,7 @@ enable_full_determinism() class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase): model_class = CogVideoXTransformer3DModel main_input_name = "hidden_states" + uses_custom_attn_processor = True @property def dummy_input(self): diff --git a/tests/models/transformers/test_models_transformer_lumina.py b/tests/models/transformers/test_models_transformer_lumina.py index 0b3e666999..6744fb8ac8 100644 --- a/tests/models/transformers/test_models_transformer_lumina.py +++ b/tests/models/transformers/test_models_transformer_lumina.py @@ -32,6 +32,7 @@ enable_full_determinism() class LuminaNextDiT2DModelTransformerTests(ModelTesterMixin, unittest.TestCase): model_class = LuminaNextDiT2DModel main_input_name = "hidden_states" + uses_custom_attn_processor = True @property def dummy_input(self): diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 618a5cff99..54c83d6a1b 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ 
b/tests/pipelines/animatediff/test_animatediff.py @@ -175,7 +175,7 @@ class AnimateDiffPipelineFastTests( def test_attention_slicing_forward_pass(self): pass - def test_ip_adapter_single(self): + def test_ip_adapter(self): expected_pipe_slice = None if torch_device == "cpu": expected_pipe_slice = np.array( @@ -209,7 +209,7 @@ class AnimateDiffPipelineFastTests( 0.5620, ] ) - return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) def test_dict_tuple_outputs_equivalent(self): expected_slice = None @@ -460,6 +460,30 @@ class AnimateDiffPipelineFastTests( "Disabling of FreeNoise should lead to results similar to the default pipeline results", ) + def test_free_noise_split_inference(self): + components = self.get_dummy_components() + pipe: AnimateDiffPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + pipe.enable_free_noise(8, 4) + + inputs_normal = self.get_dummy_inputs(torch_device) + frames_normal = pipe(**inputs_normal).frames[0] + + # Test FreeNoise with split inference memory-optimization + pipe.enable_free_noise_split_inference(spatial_split_size=16, temporal_split_size=4) + + inputs_enable_split_inference = self.get_dummy_inputs(torch_device) + frames_enable_split_inference = pipe(**inputs_enable_split_inference).frames[0] + + sum_split_inference = np.abs(to_np(frames_normal) - to_np(frames_enable_split_inference)).sum() + self.assertLess( + sum_split_inference, + 1e-4, + "Enabling FreeNoise Split Inference memory-optimizations should lead to results similar to the default pipeline results", + ) + def test_free_noise_multi_prompt(self): components = self.get_dummy_components() pipe: AnimateDiffPipeline = self.pipeline_class(**components) diff --git a/tests/pipelines/animatediff/test_animatediff_controlnet.py b/tests/pipelines/animatediff/test_animatediff_controlnet.py index c0ad223c6c..519d848c6d 100644 --- a/tests/pipelines/animatediff/test_animatediff_controlnet.py +++ b/tests/pipelines/animatediff/test_animatediff_controlnet.py @@ -193,7 +193,7 @@ class AnimateDiffControlNetPipelineFastTests( def test_attention_slicing_forward_pass(self): pass - def test_ip_adapter_single(self): + def test_ip_adapter(self): expected_pipe_slice = None if torch_device == "cpu": expected_pipe_slice = np.array( @@ -218,7 +218,7 @@ class AnimateDiffControlNetPipelineFastTests( 0.5155, ] ) - return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) def test_dict_tuple_outputs_equivalent(self): expected_slice = None diff --git a/tests/pipelines/animatediff/test_animatediff_sparsectrl.py b/tests/pipelines/animatediff/test_animatediff_sparsectrl.py index e4cc06e1e7..189d6765de 100644 --- a/tests/pipelines/animatediff/test_animatediff_sparsectrl.py +++ b/tests/pipelines/animatediff/test_animatediff_sparsectrl.py @@ -195,7 +195,7 @@ class AnimateDiffSparseControlNetPipelineFastTests( def test_attention_slicing_forward_pass(self): pass - def test_ip_adapter_single(self): + def test_ip_adapter(self): expected_pipe_slice = None if torch_device == "cpu": expected_pipe_slice = np.array( @@ -220,7 +220,7 @@ class AnimateDiffSparseControlNetPipelineFastTests( 0.5155, ] ) - return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) def 
test_dict_tuple_outputs_equivalent(self): expected_slice = None diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py index c49790e0f2..c3fd4c7373 100644 --- a/tests/pipelines/animatediff/test_animatediff_video2video.py +++ b/tests/pipelines/animatediff/test_animatediff_video2video.py @@ -175,7 +175,7 @@ class AnimateDiffVideoToVideoPipelineFastTests( def test_attention_slicing_forward_pass(self): pass - def test_ip_adapter_single(self): + def test_ip_adapter(self): expected_pipe_slice = None if torch_device == "cpu": @@ -201,7 +201,7 @@ class AnimateDiffVideoToVideoPipelineFastTests( 0.5378, ] ) - return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) def test_inference_batch_single_identical( self, @@ -492,6 +492,34 @@ class AnimateDiffVideoToVideoPipelineFastTests( "Disabling of FreeNoise should lead to results similar to the default pipeline results", ) + def test_free_noise_split_inference(self): + components = self.get_dummy_components() + pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + pipe.enable_free_noise(8, 4) + + inputs_normal = self.get_dummy_inputs(torch_device, num_frames=16) + inputs_normal["num_inference_steps"] = 2 + inputs_normal["strength"] = 0.5 + frames_normal = pipe(**inputs_normal).frames[0] + + # Test FreeNoise with split inference memory-optimization + pipe.enable_free_noise_split_inference(spatial_split_size=16, temporal_split_size=4) + + inputs_enable_split_inference = self.get_dummy_inputs(torch_device, num_frames=16) + inputs_enable_split_inference["num_inference_steps"] = 2 + inputs_enable_split_inference["strength"] = 0.5 + frames_enable_split_inference = pipe(**inputs_enable_split_inference).frames[0] + + sum_split_inference = np.abs(to_np(frames_normal) - to_np(frames_enable_split_inference)).sum() + self.assertLess( + sum_split_inference, + 1e-4, + "Enabling FreeNoise Split Inference memory-optimizations should lead to results similar to the default pipeline results", + ) + def test_free_noise_multi_prompt(self): components = self.get_dummy_components() pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components) diff --git a/tests/pipelines/cogvideox/__init__.py b/tests/pipelines/cogvideo/__init__.py similarity index 100% rename from tests/pipelines/cogvideox/__init__.py rename to tests/pipelines/cogvideo/__init__.py diff --git a/tests/pipelines/cogvideox/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py similarity index 100% rename from tests/pipelines/cogvideox/test_cogvideox.py rename to tests/pipelines/cogvideo/test_cogvideox.py diff --git a/tests/pipelines/cogvideox/test_cogvideox_video2video.py b/tests/pipelines/cogvideo/test_cogvideox_video2video.py similarity index 100% rename from tests/pipelines/cogvideox/test_cogvideox_video2video.py rename to tests/pipelines/cogvideo/test_cogvideox_video2video.py diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index a5d3a09b21..a2afc52094 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -220,11 +220,11 @@ class ControlNetPipelineFastTests( def test_attention_slicing_forward_pass(self): return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3) - def 
test_ip_adapter_single(self): + def test_ip_adapter(self): expected_pipe_slice = None if torch_device == "cpu": expected_pipe_slice = np.array([0.5234, 0.3333, 0.1745, 0.7605, 0.6224, 0.4637, 0.6989, 0.7526, 0.4665]) - return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) @unittest.skipIf( torch_device != "cuda" or not is_xformers_available(), @@ -460,11 +460,11 @@ class StableDiffusionMultiControlNetPipelineFastTests( def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) - def test_ip_adapter_single(self): + def test_ip_adapter(self): expected_pipe_slice = None if torch_device == "cpu": expected_pipe_slice = np.array([0.2422, 0.3425, 0.4048, 0.5351, 0.3503, 0.2419, 0.4645, 0.4570, 0.3804]) - return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) def test_save_pretrained_raise_not_implemented_exception(self): components = self.get_dummy_components() @@ -679,11 +679,11 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests( def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) - def test_ip_adapter_single(self): + def test_ip_adapter(self): expected_pipe_slice = None if torch_device == "cpu": expected_pipe_slice = np.array([0.5264, 0.3203, 0.1602, 0.8235, 0.6332, 0.4593, 0.7226, 0.7777, 0.4780]) - return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) def test_save_pretrained_raise_not_implemented_exception(self): components = self.get_dummy_components() diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py index 0b7ae50a21..05a484a3b8 100644 --- a/tests/pipelines/controlnet/test_controlnet_img2img.py +++ b/tests/pipelines/controlnet/test_controlnet_img2img.py @@ -173,11 +173,11 @@ class ControlNetImg2ImgPipelineFastTests( def test_attention_slicing_forward_pass(self): return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3) - def test_ip_adapter_single(self): + def test_ip_adapter(self): expected_pipe_slice = None if torch_device == "cpu": expected_pipe_slice = np.array([0.7096, 0.5149, 0.3571, 0.5897, 0.4715, 0.4052, 0.6098, 0.6886, 0.4213]) - return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) @unittest.skipIf( torch_device != "cuda" or not is_xformers_available(), @@ -371,11 +371,11 @@ class StableDiffusionMultiControlNetPipelineFastTests( def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) - def test_ip_adapter_single(self): + def test_ip_adapter(self): expected_pipe_slice = None if torch_device == "cpu": expected_pipe_slice = np.array([0.5293, 0.7339, 0.6642, 0.3950, 0.5212, 0.5175, 0.7002, 0.5907, 0.5182]) - return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) def test_save_pretrained_raise_not_implemented_exception(self): components = self.get_dummy_components() diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index 6ee83cd6c9..c931391ac4 100644 --- 
a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -190,14 +190,14 @@ class StableDiffusionXLControlNetPipelineFastTests( def test_attention_slicing_forward_pass(self): return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3) - def test_ip_adapter_single(self, from_ssd1b=False, expected_pipe_slice=None): + def test_ip_adapter(self, from_ssd1b=False, expected_pipe_slice=None): if not from_ssd1b: expected_pipe_slice = None if torch_device == "cpu": expected_pipe_slice = np.array( [0.7335, 0.5866, 0.5623, 0.6242, 0.5751, 0.5999, 0.4091, 0.4590, 0.5054] ) - return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) @unittest.skipIf( torch_device != "cuda" or not is_xformers_available(), @@ -970,12 +970,12 @@ class StableDiffusionSSD1BControlNetPipelineFastTests(StableDiffusionXLControlNe # make sure that it's equal assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-4 - def test_ip_adapter_single(self): + def test_ip_adapter(self): expected_pipe_slice = None if torch_device == "cpu": expected_pipe_slice = np.array([0.7212, 0.5890, 0.5491, 0.6425, 0.5970, 0.6091, 0.4418, 0.4556, 0.5032]) - return super().test_ip_adapter_single(from_ssd1b=True, expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(from_ssd1b=True, expected_pipe_slice=expected_pipe_slice) def test_controlnet_sdxl_lcm(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py b/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py index 99ea395ad3..6a5976bd0d 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py @@ -175,12 +175,12 @@ class ControlNetPipelineSDXLImg2ImgFastTests( return inputs - def test_ip_adapter_single(self): + def test_ip_adapter(self): expected_pipe_slice = None if torch_device == "cpu": expected_pipe_slice = np.array([0.6276, 0.5271, 0.5205, 0.5393, 0.5774, 0.5872, 0.5456, 0.5415, 0.5354]) # TODO: update after slices.p - return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) def test_stable_diffusion_xl_controlnet_img2img(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator diff --git a/tests/pipelines/flux/test_pipeline_flux_img2img.py b/tests/pipelines/flux/test_pipeline_flux_img2img.py new file mode 100644 index 0000000000..ec89f05382 --- /dev/null +++ b/tests/pipelines/flux/test_pipeline_flux_img2img.py @@ -0,0 +1,149 @@ +import random +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxImg2ImgPipeline, FluxTransformer2DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.") +class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = FluxImg2ImgPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale", 
"prompt_embeds", "pooled_prompt_embeds"]) + batch_params = frozenset(["prompt"]) + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = FluxTransformer2DModel( + patch_size=1, + in_channels=4, + num_layers=1, + num_single_layers=1, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[4, 4, 8], + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = CLIPTextModel(clip_text_encoder_config) + + torch.manual_seed(0) + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=1, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0609, + scaling_factor=1.5035, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 8, + "width": 8, + "max_sequence_length": 48, + "strength": 0.8, + "output_type": "np", + } + return inputs + + def test_flux_different_prompts(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt_2"] = "a different prompt" + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + + # Outputs should be different here + # For some reasons, they don't show large differences + assert max_diff > 1e-6 + + def test_flux_prompt_embeds(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + output_with_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = inputs.pop("prompt") + + (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt( + prompt, + prompt_2=None, + device=torch_device, + max_sequence_length=inputs["max_sequence_length"], + ) + output_with_embeds = pipe( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + **inputs, + ).images[0] + + max_diff = np.abs(output_with_prompt - output_with_embeds).max() + assert max_diff < 1e-4 diff --git a/tests/pipelines/flux/test_pipeline_flux_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_inpaint.py new file mode 100644 index 
0000000000..7ad77cb6ea --- /dev/null +++ b/tests/pipelines/flux/test_pipeline_flux_inpaint.py @@ -0,0 +1,151 @@ +import random +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxInpaintPipeline, FluxTransformer2DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.") +class FluxInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = FluxInpaintPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) + batch_params = frozenset(["prompt"]) + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = FluxTransformer2DModel( + patch_size=1, + in_channels=8, + num_layers=1, + num_single_layers=1, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[4, 4, 8], + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = CLIPTextModel(clip_text_encoder_config) + + torch.manual_seed(0) + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=2, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0609, + scaling_factor=1.5035, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + mask_image = torch.ones((1, 1, 32, 32)).to(device) + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": image, + "mask_image": mask_image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 32, + "width": 32, + "max_sequence_length": 48, + "strength": 0.8, + "output_type": "np", + } + return inputs + + def test_flux_inpaint_different_prompts(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt_2"] = "a different prompt" + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - 
output_different_prompts).max() + + # Outputs should be different here + # For some reason, they don't show large differences + assert max_diff > 1e-6 + + def test_flux_inpaint_prompt_embeds(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + output_with_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = inputs.pop("prompt") + + (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt( + prompt, + prompt_2=None, + device=torch_device, + max_sequence_length=inputs["max_sequence_length"], + ) + output_with_embeds = pipe( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + **inputs, + ).images[0] + + max_diff = np.abs(output_with_prompt - output_with_embeds).max() + assert max_diff < 1e-4 diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index 3716920abe..694a4d4574 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -550,7 +550,7 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin): max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice) assert max_diff < 5e-4 - def test_ip_adapter_single_mask(self): + def test_ip_adapter_mask(self): image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") pipeline = StableDiffusionXLPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", diff --git a/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py b/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py index 7ae5a8dd81..b60a4553cd 100644 --- a/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py +++ b/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py @@ -108,11 +108,11 @@ class LatentConsistencyModelPipelineFastTests( } return inputs - def test_ip_adapter_single(self): + def test_ip_adapter(self): expected_pipe_slice = None if torch_device == "cpu": expected_pipe_slice = np.array([0.1403, 0.5072, 0.5316, 0.1202, 0.3865, 0.4211, 0.5363, 0.3557, 0.3645]) - return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) def test_lcm_onestep(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator diff --git a/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py b/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py index 539a8dbb82..386e60c54a 100644 --- a/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py +++ b/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py @@ -119,11 +119,11 @@ class LatentConsistencyModelImg2ImgPipelineFastTests( } return inputs - def test_ip_adapter_single(self): + def test_ip_adapter(self): expected_pipe_slice = None if torch_device == "cpu": expected_pipe_slice = np.array([0.4003, 0.3718, 0.2863, 0.5500, 0.5587, 0.3772, 0.4617, 0.4961, 0.4417]) - return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) def test_lcm_onestep(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator diff --git
a/tests/pipelines/pag/test_pag_animatediff.py b/tests/pipelines/pag/test_pag_animatediff.py index 6854fb8b9a..7efe8002d1 100644 --- a/tests/pipelines/pag/test_pag_animatediff.py +++ b/tests/pipelines/pag/test_pag_animatediff.py @@ -175,7 +175,7 @@ class AnimateDiffPAGPipelineFastTests( def test_attention_slicing_forward_pass(self): pass - def test_ip_adapter_single(self): + def test_ip_adapter(self): expected_pipe_slice = None if torch_device == "cpu": @@ -210,7 +210,7 @@ class AnimateDiffPAGPipelineFastTests( 0.5538, ] ) - return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) def test_dict_tuple_outputs_equivalent(self): expected_slice = None diff --git a/tests/pipelines/pia/test_pia.py b/tests/pipelines/pia/test_pia.py index 83f550f30b..ca558fbb83 100644 --- a/tests/pipelines/pia/test_pia.py +++ b/tests/pipelines/pia/test_pia.py @@ -176,7 +176,7 @@ class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, PipelineFr assert isinstance(pipe.unet, UNetMotionModel) - def test_ip_adapter_single(self): + def test_ip_adapter(self): expected_pipe_slice = None if torch_device == "cpu": @@ -211,7 +211,7 @@ class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, PipelineFr 0.5538, ] ) - return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) def test_dict_tuple_outputs_equivalent(self): expected_slice = None diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index ec08a755e4..64ebf641d1 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -253,11 +253,11 @@ class StableDiffusionImg2ImgPipelineFastTests( assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - def test_ip_adapter_single(self): + def test_ip_adapter(self): expected_pipe_slice = None if torch_device == "cpu": expected_pipe_slice = np.array([0.4932, 0.5092, 0.5135, 0.5517, 0.5626, 0.6621, 0.6490, 0.5021, 0.5441]) - return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) def test_stable_diffusion_img2img_multiple_init_images(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index 0317e194f9..6965954f7e 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -385,14 +385,14 @@ class StableDiffusionInpaintPipelineFastTests( # they should be the same assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4) - def test_ip_adapter_single(self, from_simple=False, expected_pipe_slice=None): + def test_ip_adapter(self, from_simple=False, expected_pipe_slice=None): if not from_simple: expected_pipe_slice = None if torch_device == "cpu": expected_pipe_slice = np.array( [0.4390, 0.5452, 0.3772, 0.5448, 0.6031, 0.4480, 0.5194, 0.4687, 0.4640] ) - return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) class 
StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipelineFastTests): @@ -481,11 +481,11 @@ class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipeli } return inputs - def test_ip_adapter_single(self): + def test_ip_adapter(self): expected_pipe_slice = None if torch_device == "cpu": expected_pipe_slice = np.array([0.6345, 0.5395, 0.5611, 0.5403, 0.5830, 0.5855, 0.5193, 0.5443, 0.5211]) - return super().test_ip_adapter_single(from_simple=True, expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(from_simple=True, expected_pipe_slice=expected_pipe_slice) def test_stable_diffusion_inpaint(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py index 838f996117..9a3a93acd6 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py @@ -281,9 +281,6 @@ class StableDiffusionDepth2ImgPipelineFastTests( max_diff = np.abs(output - output_tuple).max() self.assertLess(max_diff, 1e-4) - def test_progress_bar(self): - super().test_progress_bar() - def test_stable_diffusion_depth2img_default_case(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 08cf6c1dc3..8550f25804 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -330,12 +330,12 @@ class StableDiffusionXLPipelineFastTests( # make sure that it's equal assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 - def test_ip_adapter_single(self): + def test_ip_adapter(self): expected_pipe_slice = None if torch_device == "cpu": expected_pipe_slice = np.array([0.5388, 0.5452, 0.4694, 0.4583, 0.5253, 0.4832, 0.5288, 0.5035, 0.4766]) - return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) def test_attention_slicing_forward_pass(self): super().test_attention_slicing_forward_pass(expected_max_diff=3e-3) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py index 2bc8143fee..2091af9c03 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py @@ -290,7 +290,7 @@ class StableDiffusionXLAdapterPipelineFastTests( } return inputs - def test_ip_adapter_single(self, from_multi=False, expected_pipe_slice=None): + def test_ip_adapter(self, from_multi=False, expected_pipe_slice=None): if not from_multi: expected_pipe_slice = None if torch_device == "cpu": @@ -298,7 +298,7 @@ class StableDiffusionXLAdapterPipelineFastTests( [0.5752, 0.6155, 0.4826, 0.5111, 0.5741, 0.4678, 0.5199, 0.5231, 0.4794] ) - return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) def test_stable_diffusion_adapter_default_case(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator @@ -448,12 +448,12 @@ class 
StableDiffusionXLMultiAdapterPipelineFastTests( expected_slice = np.array([0.5617, 0.6081, 0.4807, 0.5071, 0.5665, 0.4614, 0.5165, 0.5164, 0.4786]) assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3 - def test_ip_adapter_single(self): + def test_ip_adapter(self): expected_pipe_slice = None if torch_device == "cpu": expected_pipe_slice = np.array([0.5617, 0.6081, 0.4807, 0.5071, 0.5665, 0.4614, 0.5165, 0.5164, 0.4786]) - return super().test_ip_adapter_single(from_multi=True, expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(from_multi=True, expected_pipe_slice=expected_pipe_slice) def test_inference_batch_consistent( self, batch_sizes=[2, 4, 13], additional_params_copy_to_batched_inputs=["num_inference_steps"] diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py index b160eb41b7..db0905a483 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py @@ -310,12 +310,12 @@ class StableDiffusionXLImg2ImgPipelineFastTests( # make sure that it's equal assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 - def test_ip_adapter_single(self): + def test_ip_adapter(self): expected_pipe_slice = None if torch_device == "cpu": expected_pipe_slice = np.array([0.5133, 0.4626, 0.4970, 0.6273, 0.5160, 0.6891, 0.6639, 0.5892, 0.5709]) - return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) def test_stable_diffusion_xl_img2img_tiny_autoencoder(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py index 089e478836..964c7123dd 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py @@ -223,12 +223,12 @@ class StableDiffusionXLInpaintPipelineFastTests( } return inputs - def test_ip_adapter_single(self): + def test_ip_adapter(self): expected_pipe_slice = None if torch_device == "cpu": expected_pipe_slice = np.array([0.8274, 0.5538, 0.6141, 0.5843, 0.6865, 0.7082, 0.5861, 0.6123, 0.5344]) - return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice) + return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice) def test_components_function(self): init_components = self.get_dummy_components() diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index 57194acdcf..697244dcb1 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -1,6 +1,25 @@ +import contextlib +import io +import re import unittest +import torch +from PIL import Image +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AnimateDiffPipeline, + AnimateDiffVideoToVideoPipeline, + AutoencoderKL, + DDIMScheduler, + MotionAdapter, + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, + UNet2DConditionModel, +) from diffusers.pipelines.pipeline_utils import is_safetensors_compatible +from diffusers.utils.testing_utils import torch_device class IsSafetensorsCompatibleTests(unittest.TestCase): @@ -177,3 
+196,251 @@ class IsSafetensorsCompatibleTests(unittest.TestCase): "unet/diffusion_pytorch_model.fp16.safetensors", ] self.assertTrue(is_safetensors_compatible(filenames)) + + +class ProgressBarTests(unittest.TestCase): + def get_dummy_components_image_generation(self): + cross_attention_dim = 8 + + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(4, 8), + layers_per_block=1, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=cross_attention_dim, + norm_num_groups=2, + ) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[4, 8], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + norm_num_groups=2, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=cross_attention_dim, + intermediate_size=16, + layer_norm_eps=1e-05, + num_attention_heads=2, + num_hidden_layers=2, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + "image_encoder": None, + } + return components + + def get_dummy_components_video_generation(self): + cross_attention_dim = 8 + block_out_channels = (8, 8) + + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=block_out_channels, + layers_per_block=2, + sample_size=8, + in_channels=4, + out_channels=4, + down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=cross_attention_dim, + norm_num_groups=2, + ) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="linear", + clip_sample=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=block_out_channels, + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + norm_num_groups=2, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=cross_attention_dim, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + torch.manual_seed(0) + motion_adapter = MotionAdapter( + block_out_channels=block_out_channels, + motion_layers_per_block=2, + motion_norm_num_groups=2, + motion_num_attention_heads=4, + ) + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "motion_adapter": motion_adapter, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "feature_extractor": None, + "image_encoder": None, + } + return components + + def test_text_to_image(self): + components = self.get_dummy_components_image_generation() + pipe = 
StableDiffusionPipeline(**components) + pipe.to(torch_device) + + inputs = {"prompt": "a cute cat", "num_inference_steps": 2} + with io.StringIO() as stderr, contextlib.redirect_stderr(stderr): + _ = pipe(**inputs) + stderr = stderr.getvalue() + # we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img, + # so we just match "5" in "#####| 1/5 [00:01<00:00]" + max_steps = re.search("/(.*?) ", stderr).group(1) + self.assertTrue(max_steps is not None and len(max_steps) > 0) + self.assertTrue( + f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step" + ) + + pipe.set_progress_bar_config(disable=True) + with io.StringIO() as stderr, contextlib.redirect_stderr(stderr): + _ = pipe(**inputs) + self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled") + + def test_image_to_image(self): + components = self.get_dummy_components_image_generation() + pipe = StableDiffusionImg2ImgPipeline(**components) + pipe.to(torch_device) + + image = Image.new("RGB", (32, 32)) + inputs = {"prompt": "a cute cat", "num_inference_steps": 2, "strength": 0.5, "image": image} + with io.StringIO() as stderr, contextlib.redirect_stderr(stderr): + _ = pipe(**inputs) + stderr = stderr.getvalue() + # we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img, + # so we just match "5" in "#####| 1/5 [00:01<00:00]" + max_steps = re.search("/(.*?) ", stderr).group(1) + self.assertTrue(max_steps is not None and len(max_steps) > 0) + self.assertTrue( + f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step" + ) + + pipe.set_progress_bar_config(disable=True) + with io.StringIO() as stderr, contextlib.redirect_stderr(stderr): + _ = pipe(**inputs) + self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled") + + def test_inpainting(self): + components = self.get_dummy_components_image_generation() + pipe = StableDiffusionInpaintPipeline(**components) + pipe.to(torch_device) + + image = Image.new("RGB", (32, 32)) + mask = Image.new("RGB", (32, 32)) + inputs = { + "prompt": "a cute cat", + "num_inference_steps": 2, + "strength": 0.5, + "image": image, + "mask_image": mask, + } + with io.StringIO() as stderr, contextlib.redirect_stderr(stderr): + _ = pipe(**inputs) + stderr = stderr.getvalue() + # we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img, + # so we just match "5" in "#####| 1/5 [00:01<00:00]" + max_steps = re.search("/(.*?) ", stderr).group(1) + self.assertTrue(max_steps is not None and len(max_steps) > 0) + self.assertTrue( + f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step" + ) + + pipe.set_progress_bar_config(disable=True) + with io.StringIO() as stderr, contextlib.redirect_stderr(stderr): + _ = pipe(**inputs) + self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled") + + def test_text_to_video(self): + components = self.get_dummy_components_video_generation() + pipe = AnimateDiffPipeline(**components) + pipe.to(torch_device) + + inputs = {"prompt": "a cute cat", "num_inference_steps": 2, "num_frames": 2} + with io.StringIO() as stderr, contextlib.redirect_stderr(stderr): + _ = pipe(**inputs) + stderr = stderr.getvalue() + # we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img, + # so we just match "5" in "#####| 1/5 [00:01<00:00]" + max_steps = re.search("/(.*?) 
", stderr).group(1) + self.assertTrue(max_steps is not None and len(max_steps) > 0) + self.assertTrue( + f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step" + ) + + pipe.set_progress_bar_config(disable=True) + with io.StringIO() as stderr, contextlib.redirect_stderr(stderr): + _ = pipe(**inputs) + self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled") + + def test_video_to_video(self): + components = self.get_dummy_components_video_generation() + pipe = AnimateDiffVideoToVideoPipeline(**components) + pipe.to(torch_device) + + num_frames = 2 + video = [Image.new("RGB", (32, 32))] * num_frames + inputs = {"prompt": "a cute cat", "num_inference_steps": 2, "video": video} + with io.StringIO() as stderr, contextlib.redirect_stderr(stderr): + _ = pipe(**inputs) + stderr = stderr.getvalue() + # we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img, + # so we just match "5" in "#####| 1/5 [00:01<00:00]" + max_steps = re.search("/(.*?) ", stderr).group(1) + self.assertTrue(max_steps is not None and len(max_steps) > 0) + self.assertTrue( + f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step" + ) + + pipe.set_progress_bar_config(disable=True) + with io.StringIO() as stderr, contextlib.redirect_stderr(stderr): + _ = pipe(**inputs) + self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled") diff --git a/tests/pipelines/test_pipelines_auto.py b/tests/pipelines/test_pipelines_auto.py index 768026fa54..d060963f49 100644 --- a/tests/pipelines/test_pipelines_auto.py +++ b/tests/pipelines/test_pipelines_auto.py @@ -235,9 +235,32 @@ class AutoPipelineFastTest(unittest.TestCase): pipe = AutoPipelineForImage2Image.from_pretrained(repo) assert pipe.__class__.__name__ == "StableDiffusionXLImg2ImgPipeline" + controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet") + pipe_control = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet) + assert pipe_control.__class__.__name__ == "StableDiffusionXLControlNetImg2ImgPipeline" + pipe_pag = AutoPipelineForImage2Image.from_pretrained(repo, enable_pag=True) assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGImg2ImgPipeline" + pipe_control_pag = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet, enable_pag=True) + assert pipe_control_pag.__class__.__name__ == "StableDiffusionXLControlNetPAGImg2ImgPipeline" + + def test_from_pretrained_img2img_refiner(self): + repo = "hf-internal-testing/tiny-stable-diffusion-xl-refiner-pipe" + + pipe = AutoPipelineForImage2Image.from_pretrained(repo) + assert pipe.__class__.__name__ == "StableDiffusionXLImg2ImgPipeline" + + controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet") + pipe_control = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet) + assert pipe_control.__class__.__name__ == "StableDiffusionXLControlNetImg2ImgPipeline" + + pipe_pag = AutoPipelineForImage2Image.from_pretrained(repo, enable_pag=True) + assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGImg2ImgPipeline" + + pipe_control_pag = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet, enable_pag=True) + assert pipe_control_pag.__class__.__name__ == "StableDiffusionXLControlNetPAGImg2ImgPipeline" + def test_from_pipe_pag_img2img(self): # test from tableDiffusionXLPAGImg2ImgPipeline pipe = 
diff --git a/tests/pipelines/test_pipelines_auto.py b/tests/pipelines/test_pipelines_auto.py
index 768026fa54..d060963f49 100644
--- a/tests/pipelines/test_pipelines_auto.py
+++ b/tests/pipelines/test_pipelines_auto.py
@@ -235,9 +235,32 @@ class AutoPipelineFastTest(unittest.TestCase):
         pipe = AutoPipelineForImage2Image.from_pretrained(repo)
         assert pipe.__class__.__name__ == "StableDiffusionXLImg2ImgPipeline"

+        controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet")
+        pipe_control = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet)
+        assert pipe_control.__class__.__name__ == "StableDiffusionXLControlNetImg2ImgPipeline"
+
         pipe_pag = AutoPipelineForImage2Image.from_pretrained(repo, enable_pag=True)
         assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGImg2ImgPipeline"

+        pipe_control_pag = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet, enable_pag=True)
+        assert pipe_control_pag.__class__.__name__ == "StableDiffusionXLControlNetPAGImg2ImgPipeline"
+
+    def test_from_pretrained_img2img_refiner(self):
+        repo = "hf-internal-testing/tiny-stable-diffusion-xl-refiner-pipe"
+
+        pipe = AutoPipelineForImage2Image.from_pretrained(repo)
+        assert pipe.__class__.__name__ == "StableDiffusionXLImg2ImgPipeline"
+
+        controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet")
+        pipe_control = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet)
+        assert pipe_control.__class__.__name__ == "StableDiffusionXLControlNetImg2ImgPipeline"
+
+        pipe_pag = AutoPipelineForImage2Image.from_pretrained(repo, enable_pag=True)
+        assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGImg2ImgPipeline"
+
+        pipe_control_pag = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet, enable_pag=True)
+        assert pipe_control_pag.__class__.__name__ == "StableDiffusionXLControlNetPAGImg2ImgPipeline"
+
     def test_from_pipe_pag_img2img(self):
         # test from StableDiffusionXLPAGImg2ImgPipeline
         pipe = AutoPipelineForImage2Image.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
@@ -265,6 +288,16 @@ class AutoPipelineFastTest(unittest.TestCase):
         pipe_pag = AutoPipelineForInpainting.from_pretrained(repo, enable_pag=True)
         assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline"

+    def test_from_pretrained_inpaint_from_inpaint(self):
+        repo = "hf-internal-testing/tiny-stable-diffusion-xl-inpaint-pipe"
+
+        pipe = AutoPipelineForInpainting.from_pretrained(repo)
+        assert pipe.__class__.__name__ == "StableDiffusionXLInpaintPipeline"
+
+        # make sure you can use pag with inpaint-specific pipeline
+        pipe = AutoPipelineForInpainting.from_pretrained(repo, enable_pag=True)
+        assert pipe.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline"
+
     def test_from_pipe_pag_inpaint(self):
         # test from StableDiffusionXLPAGInpaintPipeline
         pipe = AutoPipelineForInpainting.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
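Taken together, these new assertions pin down the dispatch rule: each extra component or flag narrows the auto-pipeline down to a more specific concrete class. A short sketch that mirrors exactly what the tests assert, using the same tiny test fixtures (nothing here goes beyond the asserted behavior):

```python
from diffusers import AutoPipelineForImage2Image, ControlNetModel

controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet")

# Passing a ControlNet *and* enable_pag=True routes to the most specific variant.
pipe = AutoPipelineForImage2Image.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-xl-pipe",
    controlnet=controlnet,
    enable_pag=True,
)
assert pipe.__class__.__name__ == "StableDiffusionXLControlNetPAGImg2ImgPipeline"
```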
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index c3384e6b46..49da08e2ca 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -1,10 +1,7 @@
-import contextlib
 import gc
 import inspect
-import io
 import json
 import os
-import re
 import tempfile
 import unittest
 import uuid
@@ -141,52 +138,35 @@ class SDFunctionTesterMixin:
         assert np.abs(to_np(output_2) - to_np(output_1)).max() < 5e-1

         # test that tiled decode works with various shapes
-        shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
+        shapes = [(1, 4, 73, 97), (1, 4, 65, 49)]
         with torch.no_grad():
             for shape in shapes:
                 zeros = torch.zeros(shape).to(torch_device)
                 pipe.vae.decode(zeros)

-    # MPS currently doesn't support ComplexFloats, which are required for freeU - see https://github.com/huggingface/diffusers/issues/7569.
+    # MPS currently doesn't support ComplexFloats, which are required for FreeU - see https://github.com/huggingface/diffusers/issues/7569.
     @skip_mps
-    def test_freeu_enabled(self):
+    def test_freeu(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

+        # Normal inference
         inputs = self.get_dummy_inputs(torch_device)
         inputs["return_dict"] = False
         inputs["output_type"] = "np"
-
         output = pipe(**inputs)[0]

+        # FreeU-enabled inference
         pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
         inputs = self.get_dummy_inputs(torch_device)
         inputs["return_dict"] = False
         inputs["output_type"] = "np"
-
         output_freeu = pipe(**inputs)[0]
-
-        assert not np.allclose(
-            output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
-        ), "Enabling of FreeU should lead to different results."
-
-    def test_freeu_disabled(self):
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        inputs = self.get_dummy_inputs(torch_device)
-        inputs["return_dict"] = False
-        inputs["output_type"] = "np"
-
-        output = pipe(**inputs)[0]
-
-        pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
+        # FreeU-disabled inference
         pipe.disable_freeu()
-
         freeu_keys = {"s1", "s2", "b1", "b2"}
         for upsample_block in pipe.unet.up_blocks:
             for key in freeu_keys:
@@ -195,8 +175,11 @@ class SDFunctionTesterMixin:
         inputs = self.get_dummy_inputs(torch_device)
         inputs["return_dict"] = False
         inputs["output_type"] = "np"
-
         output_no_freeu = pipe(**inputs)[0]
+
+        assert not np.allclose(
+            output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
+        ), "Enabling of FreeU should lead to different results."
         assert np.allclose(
             output, output_no_freeu, atol=1e-2
         ), f"Disabling of FreeU should lead to results similar to the default pipeline results but Max Abs Error={np.abs(output_no_freeu - output).max()}."
@@ -290,7 +273,15 @@ class IPAdapterTesterMixin:
         inputs["return_dict"] = False
         return inputs

-    def test_ip_adapter_single(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
+    def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
+        r"""Tests for IP-Adapter.
+
+        The following scenarios are tested:
+        - Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
+        - Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
+        - Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
+        - Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
+        """
         # Raising the tolerance for this test when it's run on a CPU because we
         # compare against static slices and that can be shaky (with a VVVV low probability).
         expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
@@ -307,6 +298,7 @@ class IPAdapterTesterMixin:
         else:
             output_without_adapter = expected_pipe_slice

+        # 1. Single IP-Adapter test cases
         adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
         pipe.unet._load_ip_adapter_weights(adapter_state_dict)

@@ -338,16 +330,7 @@ class IPAdapterTesterMixin:
             max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference"
         )

-    def test_ip_adapter_multi(self, expected_max_diff: float = 1e-4):
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components).to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-        cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
-
-        # forward pass without ip adapter
-        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
-        output_without_adapter = pipe(**inputs)[0]
-
+        # 2. Multi IP-Adapter test cases
         adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet)
         adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet)
         pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
@@ -357,12 +340,16 @@ class IPAdapterTesterMixin:
         inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
         pipe.set_ip_adapter_scale([0.0, 0.0])
         output_without_multi_adapter_scale = pipe(**inputs)[0]
+        if expected_pipe_slice is not None:
+            output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten()

         # forward pass with multi ip adapter, but with scale of adapter weights
         inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
         inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
         pipe.set_ip_adapter_scale([42.0, 42.0])
         output_with_multi_adapter_scale = pipe(**inputs)[0]
+        if expected_pipe_slice is not None:
+            output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten()

         max_diff_without_multi_adapter_scale = np.abs(
             output_without_multi_adapter_scale - output_without_adapter
@@ -1689,28 +1676,6 @@ class PipelineTesterMixin:
         if test_mean_pixel_difference:
             assert_mean_pixel_difference(output_with_offload[0], output_without_offload[0])

-    def test_progress_bar(self):
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe.to(torch_device)
-
-        inputs = self.get_dummy_inputs(torch_device)
-        with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
-            _ = pipe(**inputs)
-            stderr = stderr.getvalue()
-            # we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img,
-            # so we just match "5" in "#####| 1/5 [00:01<00:00]"
-            max_steps = re.search("/(.*?) ", stderr).group(1)
-            self.assertTrue(max_steps is not None and len(max_steps) > 0)
-            self.assertTrue(
-                f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
-            )
-
-        pipe.set_progress_bar_config(disable=True)
-        with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
-            _ = pipe(**inputs)
-            self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
-
     def test_num_images_per_prompt(self):
         sig = inspect.signature(self.pipeline_class.__call__)
diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py
index 033addd51c..bca4fdbfae 100644
--- a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py
+++ b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py
@@ -173,9 +173,6 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, SDFunctionTesterMixin,
     def test_num_images_per_prompt(self):
         pass

-    def test_progress_bar(self):
-        return super().test_progress_bar()
-

 @slow
 @skip_mps
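For reference, the consolidated `test_freeu` in `test_pipelines_common.py` above exercises the user-facing FreeU toggle end to end. A rough sketch of that flow — `pipe` stands in for any UNet-based diffusers pipeline and is assumed, not part of the diff:

```python
# FreeU re-weights the UNet's backbone and skip features; the s/b values below
# are the ones the test uses, not tuned recommendations.
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
output_freeu = pipe("a cute cat", num_inference_steps=2, output_type="np").images[0]

# disable_freeu() clears s1/s2/b1/b2 from the UNet's up blocks, so results
# should return to the non-FreeU baseline -- exactly what the test asserts.
pipe.disable_freeu()
output_plain = pipe("a cute cat", num_inference_steps=2, output_type="np").images[0]
```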
diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py
index 8ba85455d3..8bef0cede1 100644
--- a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py
+++ b/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py
@@ -13,11 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import contextlib
 import gc
 import inspect
-import io
-import re
 import tempfile
 import unittest
@@ -282,28 +279,6 @@ class TextToVideoZeroSDXLPipelineFastTests(PipelineTesterMixin, PipelineFromPipe
     def test_pipeline_call_signature(self):
         pass

-    def test_progress_bar(self):
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe.to(torch_device)
-
-        inputs = self.get_dummy_inputs(self.generator_device)
-        with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
-            _ = pipe(**inputs)
-            stderr = stderr.getvalue()
-            # we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img,
-            # so we just match "5" in "#####| 1/5 [00:01<00:00]"
-            max_steps = re.search("/(.*?) ", stderr).group(1)
-            self.assertTrue(max_steps is not None and len(max_steps) > 0)
-            self.assertTrue(
-                f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
-            )
-
-        pipe.set_progress_bar_config(disable=True)
-        with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
-            _ = pipe(**inputs)
-            self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
-
     @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
     def test_save_load_float16(self, expected_max_diff=1e-2):
         components = self.get_dummy_components()
diff --git a/tests/pipelines/text_to_video_synthesis/test_video_to_video.py b/tests/pipelines/text_to_video_synthesis/test_video_to_video.py
index 7f28d12a73..34ccb09e22 100644
--- a/tests/pipelines/text_to_video_synthesis/test_video_to_video.py
+++ b/tests/pipelines/text_to_video_synthesis/test_video_to_video.py
@@ -197,9 +197,6 @@ class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     def test_num_images_per_prompt(self):
         pass

-    def test_progress_bar(self):
-        return super().test_progress_bar()
-

 @nightly
 @skip_mps
diff --git a/tests/single_file/single_file_testing_utils.py b/tests/single_file/single_file_testing_utils.py
index b2bb7fe827..9b89578c5a 100644
--- a/tests/single_file/single_file_testing_utils.py
+++ b/tests/single_file/single_file_testing_utils.py
@@ -5,6 +5,7 @@ import requests
 import torch
 from huggingface_hub import hf_hub_download, snapshot_download

+from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
 from diffusers.models.attention_processor import AttnProcessor
 from diffusers.utils.testing_utils import (
     numpy_cosine_similarity_distance,
@@ -98,8 +99,8 @@ class SDSingleFileTesterMixin:
         pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)

             single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
                 local_ckpt_path, safety_checker=None, local_files_only=True
@@ -138,8 +139,8 @@ class SDSingleFileTesterMixin:
         upcast_attention = pipe.unet.config.upcast_attention

         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
             local_original_config = download_original_config(self.original_config, tmpdir)

             single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
@@ -191,8 +192,8 @@ class SDSingleFileTesterMixin:
         pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
             local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)

             single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
@@ -286,8 +287,8 @@ class SDXLSingleFileTesterMixin:
         pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)

             single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
                 local_ckpt_path, safety_checker=None, local_files_only=True
@@ -327,8 +328,8 @@ class SDXLSingleFileTesterMixin:
         upcast_attention = pipe.unet.config.upcast_attention

         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
             local_original_config = download_original_config(self.original_config, tmpdir)

             single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
@@ -364,8 +365,8 @@ class SDXLSingleFileTesterMixin:
         pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
             local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)

             single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
diff --git a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py
index 1af3f5126f..3e4c1eaaa5 100644
--- a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py
+++ b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py
@@ -5,6 +5,7 @@ import unittest
 import torch

 from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
+from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
@@ -29,11 +30,11 @@ enable_full_determinism()
 @require_torch_gpu
 class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
     pipeline_class = StableDiffusionControlNetPipeline
-    ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
+    ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_pruned.safetensors"
     original_config = (
         "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
     )
-    repo_id = "runwayml/stable-diffusion-v1-5"
+    repo_id = "Lykon/dreamshaper-8"

     def setUp(self):
         super().setUp()
@@ -108,8 +109,8 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
         pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)

         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weights_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weights_name, tmpdir)

             pipe_single_file = self.pipeline_class.from_single_file(
                 local_ckpt_path, controlnet=controlnet, safety_checker=None, local_files_only=True
@@ -136,8 +137,9 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
         )

         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weights_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weights_name, tmpdir)
+
             local_original_config = download_original_config(self.original_config, tmpdir)

             pipe_single_file = self.pipeline_class.from_single_file(
@@ -168,8 +170,9 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
         )

         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weights_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weights_name, tmpdir)
+
             local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)

             pipe_single_file = self.pipeline_class.from_single_file(
diff --git a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py
index 1966ecfc20..d7ccdbd89c 100644
--- a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py
+++ b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py
@@ -5,6 +5,7 @@ import unittest
 import torch

 from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
+from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
@@ -28,9 +29,9 @@ enable_full_determinism()
 @require_torch_gpu
 class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
     pipeline_class = StableDiffusionControlNetInpaintPipeline
-    ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
+    ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_INPAINTING.inpainting.safetensors"
     original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml"
-    repo_id = "runwayml/stable-diffusion-inpainting"
+    repo_id = "Lykon/dreamshaper-8-inpainting"

     def setUp(self):
         super().setUp()
@@ -83,7 +84,7 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC
         output_sf = pipe_sf(**inputs).images[0]

         max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten())
-        assert max_diff < 1e-3
+        assert max_diff < 2e-3

     def test_single_file_components(self):
         controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
@@ -103,8 +104,8 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC
         pipe = self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None, controlnet=controlnet)

         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)

             pipe_single_file = self.pipeline_class.from_single_file(
                 local_ckpt_path, controlnet=controlnet, safety_checker=None, local_files_only=True
@@ -112,6 +113,7 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC

             super()._compare_component_configs(pipe, pipe_single_file)

+    @unittest.skip("runwayml original config repo does not exist")
     def test_single_file_components_with_original_config(self):
         controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
         pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
@@ -121,6 +123,7 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC

         super()._compare_component_configs(pipe, pipe_single_file)

+    @unittest.skip("runwayml original config repo does not exist")
     def test_single_file_components_with_original_config_local_files_only(self):
         controlnet = ControlNetModel.from_pretrained(
             "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
@@ -132,8 +135,8 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC
         )

         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
             local_original_config = download_original_config(self.original_config, tmpdir)

             pipe_single_file = self.pipeline_class.from_single_file(
@@ -169,8 +172,8 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC
         )

         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
             local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)

             pipe_single_file = self.pipeline_class.from_single_file(
diff --git a/tests/single_file/test_stable_diffusion_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_single_file.py
index fe066f02cf..4bd7f025f6 100644
--- a/tests/single_file/test_stable_diffusion_controlnet_single_file.py
+++ b/tests/single_file/test_stable_diffusion_controlnet_single_file.py
@@ -5,6 +5,7 @@ import unittest
 import torch

 from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
+from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
@@ -28,11 +29,11 @@ enable_full_determinism()
 @require_torch_gpu
 class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
     pipeline_class = StableDiffusionControlNetPipeline
-    ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
+    ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_pruned.safetensors"
     original_config = (
         "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
     )
-    repo_id = "runwayml/stable-diffusion-v1-5"
+    repo_id = "Lykon/dreamshaper-8"

     def setUp(self):
         super().setUp()
@@ -98,8 +99,8 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
         pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)

         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)

             pipe_single_file = self.pipeline_class.from_single_file(
                 local_ckpt_path, controlnet=controlnet, local_files_only=True
@@ -126,8 +127,8 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
         )

         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
             local_original_config = download_original_config(self.original_config, tmpdir)

             pipe_single_file = self.pipeline_class.from_single_file(
@@ -157,8 +158,8 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
         )

         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
             local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)

             pipe_single_file = self.pipeline_class.from_single_file(
diff --git a/tests/single_file/test_stable_diffusion_img2img_single_file.py b/tests/single_file/test_stable_diffusion_img2img_single_file.py
index 1359e66b2c..cbb5e9c3ee 100644
--- a/tests/single_file/test_stable_diffusion_img2img_single_file.py
+++ b/tests/single_file/test_stable_diffusion_img2img_single_file.py
@@ -23,11 +23,11 @@ enable_full_determinism()
 @require_torch_gpu
 class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
     pipeline_class = StableDiffusionImg2ImgPipeline
-    ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
+    ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_pruned.safetensors"
     original_config = (
         "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
     )
-    repo_id = "runwayml/stable-diffusion-v1-5"
+    repo_id = "Lykon/dreamshaper-8"

     def setUp(self):
         super().setUp()
diff --git a/tests/single_file/test_stable_diffusion_inpaint_single_file.py b/tests/single_file/test_stable_diffusion_inpaint_single_file.py
index 3fc7284464..3e133c6ea9 100644
--- a/tests/single_file/test_stable_diffusion_inpaint_single_file.py
+++ b/tests/single_file/test_stable_diffusion_inpaint_single_file.py
@@ -23,9 +23,9 @@ enable_full_determinism()
 @require_torch_gpu
 class StableDiffusionInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
     pipeline_class = StableDiffusionInpaintPipeline
-    ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
+    ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_INPAINTING.inpainting.safetensors"
     original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml"
-    repo_id = "runwayml/stable-diffusion-inpainting"
+    repo_id = "Lykon/dreamshaper-8-inpainting"

     def setUp(self):
         super().setUp()
@@ -63,11 +63,19 @@ class StableDiffusionInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSin

     def test_single_file_loading_4_channel_unet(self):
         # Test loading single file inpaint with a 4 channel UNet
-        ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
+        ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_pruned.safetensors"
         pipe = self.pipeline_class.from_single_file(ckpt_path)

         assert pipe.unet.config.in_channels == 4

+    @unittest.skip("runwayml original config has been removed")
+    def test_single_file_components_with_original_config(self):
+        return
+
+    @unittest.skip("runwayml original config has been removed")
+    def test_single_file_components_with_original_config_local_files_only(self):
+        return
+

 @slow
 @require_torch_gpu
diff --git a/tests/single_file/test_stable_diffusion_single_file.py b/tests/single_file/test_stable_diffusion_single_file.py
index 99c884fae0..1283d4d991 100644
--- a/tests/single_file/test_stable_diffusion_single_file.py
+++ b/tests/single_file/test_stable_diffusion_single_file.py
@@ -5,6 +5,7 @@ import unittest
 import torch

 from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline
+from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     require_torch_gpu,
@@ -25,11 +26,11 @@ enable_full_determinism()
 @require_torch_gpu
 class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
     pipeline_class = StableDiffusionPipeline
-    ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
+    ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_pruned.safetensors"
     original_config = (
         "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
     )
-    repo_id = "runwayml/stable-diffusion-v1-5"
+    repo_id = "Lykon/dreamshaper-8"

     def setUp(self):
         super().setUp()
@@ -58,8 +59,8 @@ class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFile

     def test_single_file_legacy_scheduler_loading(self):
         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
             local_original_config = download_original_config(self.original_config, tmpdir)

             pipe = self.pipeline_class.from_single_file(
diff --git a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py
index 7f478133c6..ead77a1d65 100644
--- a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py
@@ -8,6 +8,7 @@ from diffusers import (
     StableDiffusionXLAdapterPipeline,
     T2IAdapter,
 )
+from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
@@ -118,8 +119,8 @@ class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDX
         )

         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)

             single_file_pipe = self.pipeline_class.from_single_file(
                 local_ckpt_path, adapter=adapter, safety_checker=None, local_files_only=True
@@ -150,8 +151,8 @@ class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDX
         )

         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
             local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)

             pipe_single_file = self.pipeline_class.from_single_file(
@@ -188,8 +189,8 @@ class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDX
         )

         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
             local_original_config = download_original_config(self.original_config, tmpdir)

             pipe_single_file = self.pipeline_class.from_single_file(
diff --git a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py
index a8509510ad..9491adf2df 100644
--- a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py
@@ -5,6 +5,7 @@ import unittest
 import torch

 from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
+from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
@@ -112,8 +113,8 @@ class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase,
         )

         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)

             single_file_pipe = self.pipeline_class.from_single_file(
                 local_ckpt_path, controlnet=controlnet, safety_checker=None, local_files_only=True
@@ -151,8 +152,8 @@ class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase,
         )

         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)

             pipe_single_file = self.pipeline_class.from_single_file(
                 local_ckpt_path,
@@ -183,8 +184,8 @@ class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase,
         )

         with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt_filename = self.ckpt_path.split("/")[-1]
-            local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
+            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
+            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
             local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)

             pipe_single_file = self.pipeline_class.from_single_file(