diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 8103c01643..251eb25899 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -529,8 +529,6 @@ title: Kandinsky 2.2 - local: api/pipelines/kandinsky3 title: Kandinsky 3 - - local: api/pipelines/kandinsky5 - title: Kandinsky 5 - local: api/pipelines/kolors title: Kolors - local: api/pipelines/latent_consistency_models @@ -656,6 +654,8 @@ title: Text2Video-Zero - local: api/pipelines/wan title: Wan + - local: api/pipelines/kandinsky5_video + title: Kandinsky 5.0 Video title: Video title: Pipelines - sections: diff --git a/docs/source/en/api/pipelines/kandinsky5.md b/docs/source/en/api/pipelines/kandinsky5_video.md similarity index 90% rename from docs/source/en/api/pipelines/kandinsky5.md rename to docs/source/en/api/pipelines/kandinsky5_video.md index a98a0826b7..533db23e1c 100644 --- a/docs/source/en/api/pipelines/kandinsky5.md +++ b/docs/source/en/api/pipelines/kandinsky5_video.md @@ -7,9 +7,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Kandinsky 5.0 +# Kandinsky 5.0 Video -Kandinsky 5.0 is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov +Kandinsky 5.0 Video is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov Kandinsky 5.0 is a family of diffusion models for Video & Image generation. Kandinsky 5.0 T2V Lite is a lightweight video generation model (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem. @@ -92,7 +92,7 @@ pipe = pipe.to("cuda") pipe.transformer.set_attention_backend( "flex" -) # <--- Set attention backend to Flex +) # <--- Sett attention bakend to Flex pipe.transformer.compile( mode="max-autotune-no-cudagraphs", dynamic=True @@ -115,7 +115,7 @@ export_to_video(output, "output.mp4", fps=24, quality=9) ``` ### Diffusion Distilled model -**⚠️ Warning!** all nocfg and diffusion distilled models should be inferred without CFG (```guidance_scale=1.0```): +**⚠️ Warning!** all nocfg and diffusion distilled models should be infered wothout CFG (```guidance_scale=1.0```): ```python model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers" diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 2a1895d4e9..6ecf97701f 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -640,6 +640,86 @@ def _( # ===== Helper functions to use attention backends with templated CP autograd functions ===== +def _native_attention_forward_op( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _save_ctx: bool = True, + _parallel_config: Optional["ParallelConfig"] = None, +): + # Native attention does not return_lse + if return_lse: + raise ValueError("Native attention does not support return_lse=True") + + # used for backward pass + if _save_ctx: + ctx.save_for_backward(query, key, value) + ctx.attn_mask = attn_mask + ctx.dropout_p = dropout_p + ctx.is_causal = is_causal + ctx.scale = scale + ctx.enable_gqa = enable_gqa + + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + + return out + + +def _native_attention_backward_op( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + **kwargs, +): + query, key, value = ctx.saved_tensors + + query.requires_grad_(True) + key.requires_grad_(True) + value.requires_grad_(True) + + query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = torch.nn.functional.scaled_dot_product_attention( + query=query_t, + key=key_t, + value=value_t, + attn_mask=ctx.attn_mask, + dropout_p=ctx.dropout_p, + is_causal=ctx.is_causal, + scale=ctx.scale, + enable_gqa=ctx.enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + + grad_out_t = grad_out.permute(0, 2, 1, 3) + grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad( + outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False + ) + + grad_query = grad_query_t.permute(0, 2, 1, 3) + grad_key = grad_key_t.permute(0, 2, 1, 3) + grad_value = grad_value_t.permute(0, 2, 1, 3) + + return grad_query, grad_key, grad_value + + # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958 # forward declaration: # aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) @@ -1514,6 +1594,7 @@ def _native_flex_attention( @_AttentionBackendRegistry.register( AttentionBackendName.NATIVE, constraints=[_check_device, _check_shape], + supports_context_parallel=True, ) def _native_attention( query: torch.Tensor, @@ -1529,18 +1610,35 @@ def _native_attention( ) -> torch.Tensor: if return_lse: raise ValueError("Native attention backend does not support setting `return_lse=True`.") - query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - out = torch.nn.functional.scaled_dot_product_attention( - query=query, - key=key, - value=value, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=is_causal, - scale=scale, - enable_gqa=enable_gqa, - ) - out = out.permute(0, 2, 1, 3) + if _parallel_config is None: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + else: + out = _templated_context_parallel_attention( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op=_native_attention_forward_op, + backward_op=_native_attention_backward_op, + _parallel_config=_parallel_config, + ) + return out diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py index 947b610ea6..c96b4fa88c 100644 --- a/src/diffusers/models/auto_model.py +++ b/src/diffusers/models/auto_model.py @@ -147,14 +147,13 @@ class AutoModel(ConfigMixin): "force_download", "local_files_only", "proxies", - "resume_download", "revision", "token", ] hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names} # load_config_kwargs uses the same hub kwargs minus subfolder and resume_download - load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder", "resume_download"]} + load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder"]} library = None orig_class_name = None @@ -205,7 +204,6 @@ class AutoModel(ConfigMixin): module_file=module_file, class_name=class_name, **hub_kwargs, - **kwargs, ) else: from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index 9dd8035c44..cb7e8fb736 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -164,7 +164,11 @@ class AutoOffloadStrategy: device_type = execution_device.type device_module = getattr(torch, device_type, torch.cuda) - mem_on_device = device_module.mem_get_info(execution_device.index)[0] + try: + mem_on_device = device_module.mem_get_info(execution_device.index)[0] + except AttributeError: + raise AttributeError(f"Do not know how to obtain obtain memory info for {str(device_module)}.") + mem_on_device = mem_on_device - self.memory_reserve_margin if current_module_size < mem_on_device: return [] @@ -699,6 +703,8 @@ class ComponentsManager: if not is_accelerate_available(): raise ImportError("Make sure to install accelerate to use auto_cpu_offload") + # TODO: add a warning if mem_get_info isn't available on `device`. + for name, component in self.components.items(): if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"): remove_hook_from_module(component, recurse=True) diff --git a/src/diffusers/modular_pipelines/flux/before_denoise.py b/src/diffusers/modular_pipelines/flux/before_denoise.py index c098b7d4f1..daffec9865 100644 --- a/src/diffusers/modular_pipelines/flux/before_denoise.py +++ b/src/diffusers/modular_pipelines/flux/before_denoise.py @@ -598,7 +598,7 @@ class FluxKontextRoPEInputsStep(ModularPipelineBlocks): and getattr(block_state, "image_width", None) is not None ): image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2)) - image_latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) + image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2)) img_ids = FluxPipeline._prepare_latent_image_ids( None, image_latent_height // 2, image_latent_width // 2, device, dtype ) diff --git a/src/diffusers/modular_pipelines/flux/denoise.py b/src/diffusers/modular_pipelines/flux/denoise.py index b1796bb63c..5a769df103 100644 --- a/src/diffusers/modular_pipelines/flux/denoise.py +++ b/src/diffusers/modular_pipelines/flux/denoise.py @@ -59,7 +59,7 @@ class FluxLoopDenoiser(ModularPipelineBlocks): ), InputParam( "guidance", - required=True, + required=False, type_hint=torch.Tensor, description="Guidance scale as a tensor", ), @@ -141,7 +141,7 @@ class FluxKontextLoopDenoiser(ModularPipelineBlocks): ), InputParam( "guidance", - required=True, + required=False, type_hint=torch.Tensor, description="Guidance scale as a tensor", ), diff --git a/src/diffusers/modular_pipelines/flux/encoders.py b/src/diffusers/modular_pipelines/flux/encoders.py index b71962bd93..f0314d4771 100644 --- a/src/diffusers/modular_pipelines/flux/encoders.py +++ b/src/diffusers/modular_pipelines/flux/encoders.py @@ -95,7 +95,7 @@ class FluxProcessImagesInputStep(ModularPipelineBlocks): ComponentSpec( "image_processor", VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 16}), + config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}), default_creation_method="from_config", ), ] @@ -143,10 +143,6 @@ class FluxProcessImagesInputStep(ModularPipelineBlocks): class FluxKontextProcessImagesInputStep(ModularPipelineBlocks): model_name = "flux-kontext" - def __init__(self, _auto_resize=True): - self._auto_resize = _auto_resize - super().__init__() - @property def description(self) -> str: return ( @@ -167,7 +163,7 @@ class FluxKontextProcessImagesInputStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: - return [InputParam("image")] + return [InputParam("image"), InputParam("_auto_resize", type_hint=bool, default=True)] @property def intermediate_outputs(self) -> List[OutputParam]: @@ -195,7 +191,8 @@ class FluxKontextProcessImagesInputStep(ModularPipelineBlocks): img = images[0] image_height, image_width = components.image_processor.get_default_height_width(img) aspect_ratio = image_width / image_height - if self._auto_resize: + _auto_resize = block_state._auto_resize + if _auto_resize: # Kontext is trained on specific resolutions, using one of them is recommended _, image_width, image_height = min( (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS diff --git a/src/diffusers/modular_pipelines/flux/inputs.py b/src/diffusers/modular_pipelines/flux/inputs.py index e1bc17f5ff..8309eebfeb 100644 --- a/src/diffusers/modular_pipelines/flux/inputs.py +++ b/src/diffusers/modular_pipelines/flux/inputs.py @@ -112,6 +112,10 @@ class FluxTextInputStep(ModularPipelineBlocks): block_state.prompt_embeds = block_state.prompt_embeds.view( block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 ) + pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt) + block_state.pooled_prompt_embeds = pooled_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, -1 + ) self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 55c261ab2f..307698245e 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -305,15 +305,15 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): "cache_dir", "force_download", "local_files_only", + "local_dir", "proxies", - "resume_download", "revision", "subfolder", "token", ] hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs} - config = cls.load_config(pretrained_model_name_or_path) + config = cls.load_config(pretrained_model_name_or_path, **hub_kwargs) has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"] trust_remote_code = resolve_trust_remote_code( trust_remote_code, pretrained_model_name_or_path, has_remote_code @@ -331,7 +331,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): module_file=module_file, class_name=class_name, **hub_kwargs, - **kwargs, ) expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls) block_kwargs = { @@ -2131,8 +2130,13 @@ class ModularPipeline(ConfigMixin, PushToHubMixin): component_load_kwargs[key] = value["default"] try: components_to_register[name] = spec.load(**component_load_kwargs) - except Exception as e: - logger.warning(f"Failed to create component '{name}': {e}") + except Exception: + logger.warning( + f"\nFailed to create component {name}:\n" + f"- Component spec: {spec}\n" + f"- load() called with kwargs: {component_load_kwargs}\n\n" + f"{traceback.format_exc()}" + ) # Register all components at once self.register_components(**components_to_register) diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py index f67a0e2112..d605eac1f2 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py @@ -355,7 +355,7 @@ class StableDiffusion3ControlNetPipeline( prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds, pooled_prompt_embeds diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py index 68984da4dc..9d0158c6b6 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py @@ -373,7 +373,7 @@ class StableDiffusion3ControlNetInpaintingPipeline( prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds, pooled_prompt_embeds diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py index bc281428e2..941b675099 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py @@ -326,7 +326,7 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds, pooled_prompt_embeds diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py index 22a8dac238..f40dd52fc2 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py @@ -342,7 +342,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds, pooled_prompt_embeds diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 3b7b26dc63..660d9801df 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -336,7 +336,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds, pooled_prompt_embeds diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py index db047f1992..9b11bc8781 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py @@ -361,7 +361,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds, pooled_prompt_embeds diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py index c95fa530c8..b947cbff09 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py @@ -367,7 +367,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds, pooled_prompt_embeds diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py index 627b1e0604..b2ef5a29e0 100644 --- a/src/diffusers/utils/dynamic_modules_utils.py +++ b/src/diffusers/utils/dynamic_modules_utils.py @@ -254,6 +254,7 @@ def get_cached_module_file( token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, local_files_only: bool = False, + local_dir: Optional[str] = None, ): """ Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached @@ -332,6 +333,7 @@ def get_cached_module_file( force_download=force_download, proxies=proxies, local_files_only=local_files_only, + local_dir=local_dir, ) submodule = "git" module_file = pretrained_model_name_or_path + ".py" @@ -355,6 +357,7 @@ def get_cached_module_file( force_download=force_download, proxies=proxies, local_files_only=local_files_only, + local_dir=local_dir, token=token, ) submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/"))) @@ -415,6 +418,7 @@ def get_cached_module_file( token=token, revision=revision, local_files_only=local_files_only, + local_dir=local_dir, ) return os.path.join(full_submodule, module_file) @@ -431,7 +435,7 @@ def get_class_from_dynamic_module( token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, local_files_only: bool = False, - **kwargs, + local_dir: Optional[str] = None, ): """ Extracts a class from a module file, present in the local folder or repository of a model. @@ -496,5 +500,6 @@ def get_class_from_dynamic_module( token=token, revision=revision, local_files_only=local_files_only, + local_dir=local_dir, ) return get_class_in_module(class_name, final_module) diff --git a/tests/modular_pipelines/flux/__init__.py b/tests/modular_pipelines/flux/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/modular_pipelines/flux/test_modular_pipeline_flux.py b/tests/modular_pipelines/flux/test_modular_pipeline_flux.py new file mode 100644 index 0000000000..9d70c21aa8 --- /dev/null +++ b/tests/modular_pipelines/flux/test_modular_pipeline_flux.py @@ -0,0 +1,130 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import tempfile +import unittest + +import numpy as np +import PIL +import torch + +from diffusers.image_processor import VaeImageProcessor +from diffusers.modular_pipelines import ( + FluxAutoBlocks, + FluxKontextAutoBlocks, + FluxKontextModularPipeline, + FluxModularPipeline, + ModularPipeline, +) + +from ...testing_utils import floats_tensor, torch_device +from ..test_modular_pipelines_common import ModularPipelineTesterMixin + + +class FluxModularTests: + pipeline_class = FluxModularPipeline + pipeline_blocks_class = FluxAutoBlocks + repo = "hf-internal-testing/tiny-flux-modular" + + def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): + pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager) + pipeline.load_components(torch_dtype=torch_dtype) + return pipeline + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 8, + "width": 8, + "max_sequence_length": 48, + "output_type": "np", + } + return inputs + + +class FluxModularPipelineFastTests(FluxModularTests, ModularPipelineTesterMixin, unittest.TestCase): + params = frozenset(["prompt", "height", "width", "guidance_scale"]) + batch_params = frozenset(["prompt"]) + + +class FluxImg2ImgModularPipelineFastTests(FluxModularTests, ModularPipelineTesterMixin, unittest.TestCase): + params = frozenset(["prompt", "height", "width", "guidance_scale", "image"]) + batch_params = frozenset(["prompt", "image"]) + + def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): + pipeline = super().get_pipeline(components_manager, torch_dtype) + # Override `vae_scale_factor` here as currently, `image_processor` is initialized with + # fixed constants instead of + # https://github.com/huggingface/diffusers/blob/d54622c2679d700b425ad61abce9b80fc36212c0/src/diffusers/pipelines/flux/pipeline_flux_img2img.py#L230C9-L232C10 + pipeline.image_processor = VaeImageProcessor(vae_scale_factor=2) + return pipeline + + def get_dummy_inputs(self, device, seed=0): + inputs = super().get_dummy_inputs(device, seed) + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image / 2 + 0.5 + inputs["image"] = image + inputs["strength"] = 0.8 + inputs["height"] = 8 + inputs["width"] = 8 + return inputs + + def test_save_from_pretrained(self): + pipes = [] + base_pipe = self.get_pipeline().to(torch_device) + pipes.append(base_pipe) + + with tempfile.TemporaryDirectory() as tmpdirname: + base_pipe.save_pretrained(tmpdirname) + pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device) + pipe.load_components(torch_dtype=torch.float32) + pipe.to(torch_device) + pipe.image_processor = VaeImageProcessor(vae_scale_factor=2) + + pipes.append(pipe) + + image_slices = [] + for pipe in pipes: + inputs = self.get_dummy_inputs(torch_device) + image = pipe(**inputs, output="images") + + image_slices.append(image[0, -3:, -3:, -1].flatten()) + + assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + + +class FluxKontextModularPipelineFastTests(FluxImg2ImgModularPipelineFastTests): + pipeline_class = FluxKontextModularPipeline + pipeline_blocks_class = FluxKontextAutoBlocks + repo = "hf-internal-testing/tiny-flux-kontext-pipe" + + def get_dummy_inputs(self, device, seed=0): + inputs = super().get_dummy_inputs(device, seed) + image = PIL.Image.new("RGB", (32, 32), 0) + _ = inputs.pop("strength") + inputs["image"] = image + inputs["height"] = 8 + inputs["width"] = 8 + inputs["max_area"] = 8 * 8 + inputs["_auto_resize"] = False + return inputs diff --git a/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py index d05f818135..22347aa558 100644 --- a/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py +++ b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py @@ -21,24 +21,12 @@ import numpy as np import torch from PIL import Image -from diffusers import ( - ClassifierFreeGuidance, - StableDiffusionXLAutoBlocks, - StableDiffusionXLModularPipeline, -) +from diffusers import ClassifierFreeGuidance, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline from diffusers.loaders import ModularIPAdapterMixin -from ...models.unets.test_models_unet_2d_condition import ( - create_ip_adapter_state_dict, -) -from ...testing_utils import ( - enable_full_determinism, - floats_tensor, - torch_device, -) -from ..test_modular_pipelines_common import ( - ModularPipelineTesterMixin, -) +from ...models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict +from ...testing_utils import enable_full_determinism, floats_tensor, torch_device +from ..test_modular_pipelines_common import ModularPipelineTesterMixin enable_full_determinism()