From de6173c683d72feadc039d30d323427d6265e616 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Mon, 3 Nov 2025 09:44:42 -1000 Subject: [PATCH 1/5] [modular]pass hub_kwargs to load_config (#12577) pass hub_kwargs to load_config --- src/diffusers/modular_pipelines/modular_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 8076f41c15..307698245e 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -313,7 +313,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): ] hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs} - config = cls.load_config(pretrained_model_name_or_path) + config = cls.load_config(pretrained_model_name_or_path, **hub_kwargs) has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"] trust_remote_code = resolve_trust_remote_code( trust_remote_code, pretrained_model_name_or_path, has_remote_code From 1ec28a2c770546e4483a5109c4e98d6e1b298968 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Tue, 4 Nov 2025 05:48:20 +0800 Subject: [PATCH 2/5] ulysses enabling in native attention path (#12563) * ulysses enabling in native attention path Signed-off-by: Wang, Yi A * address review comment Signed-off-by: Wang, Yi A * add supports_context_parallel for native attention Signed-off-by: Wang, Yi A * update templated attention Signed-off-by: Wang, Yi A --------- Signed-off-by: Wang, Yi A Co-authored-by: Sayak Paul --- src/diffusers/models/attention_dispatch.py | 122 +++++++++++++++++++-- 1 file changed, 110 insertions(+), 12 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index ab0d7102ee..c17a3d0ed6 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -649,6 +649,86 @@ def _( # ===== Helper functions to use attention backends with templated CP autograd functions ===== +def _native_attention_forward_op( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _save_ctx: bool = True, + _parallel_config: Optional["ParallelConfig"] = None, +): + # Native attention does not return_lse + if return_lse: + raise ValueError("Native attention does not support return_lse=True") + + # used for backward pass + if _save_ctx: + ctx.save_for_backward(query, key, value) + ctx.attn_mask = attn_mask + ctx.dropout_p = dropout_p + ctx.is_causal = is_causal + ctx.scale = scale + ctx.enable_gqa = enable_gqa + + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + + return out + + +def _native_attention_backward_op( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + **kwargs, +): + query, key, value = ctx.saved_tensors + + query.requires_grad_(True) + key.requires_grad_(True) + value.requires_grad_(True) + + query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = torch.nn.functional.scaled_dot_product_attention( + query=query_t, + key=key_t, + value=value_t, + attn_mask=ctx.attn_mask, + dropout_p=ctx.dropout_p, + is_causal=ctx.is_causal, + scale=ctx.scale, + enable_gqa=ctx.enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + + grad_out_t = grad_out.permute(0, 2, 1, 3) + grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad( + outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False + ) + + grad_query = grad_query_t.permute(0, 2, 1, 3) + grad_key = grad_key_t.permute(0, 2, 1, 3) + grad_value = grad_value_t.permute(0, 2, 1, 3) + + return grad_query, grad_key, grad_value + + # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958 # forward declaration: # aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) @@ -1523,6 +1603,7 @@ def _native_flex_attention( @_AttentionBackendRegistry.register( AttentionBackendName.NATIVE, constraints=[_check_device, _check_shape], + supports_context_parallel=True, ) def _native_attention( query: torch.Tensor, @@ -1538,18 +1619,35 @@ def _native_attention( ) -> torch.Tensor: if return_lse: raise ValueError("Native attention backend does not support setting `return_lse=True`.") - query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - out = torch.nn.functional.scaled_dot_product_attention( - query=query, - key=key, - value=value, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=is_causal, - scale=scale, - enable_gqa=enable_gqa, - ) - out = out.permute(0, 2, 1, 3) + if _parallel_config is None: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + else: + out = _templated_context_parallel_attention( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op=_native_attention_forward_op, + backward_op=_native_attention_backward_op, + _parallel_config=_parallel_config, + ) + return out From 325a95051bb20787d3db9faba9b4f62b2a63c43a Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Tue, 4 Nov 2025 03:38:07 +0300 Subject: [PATCH 3/5] Kandinsky 5.0 Docs fixes (#12582) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add transformer pipeline first version * updates * fix 5sec generation * rewrite Kandinsky5T2VPipeline to diffusers style * add multiprompt support * remove prints in pipeline * add nabla attention * Wrap Transformer in Diffusers style * fix license * fix prompt type * add gradient checkpointing and peft support * add usage example * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: Álvaro Somoza * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: Álvaro Somoza * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: Álvaro Somoza * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: Álvaro Somoza * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: Álvaro Somoza * remove unused imports * add 10 second models support * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * remove no_grad and simplified prompt paddings * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * moved template to __init__ * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * moved sdps inside processor * remove oneline function * remove reset_dtype methods * Transformer: move all methods to forward * separated prompt encoding * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * refactoring * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * refactoring acording to https://github.com/huggingface/diffusers/commit/acabbc0033d4b4933fc651766a4aa026db2e6dc1 * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * fixed * style +copies * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: Charles * more * Apply suggestions from code review * add lora loader doc * add compiled Nabla Attention * all needed changes for 10 sec models are added! * add docs * Apply style fixes * update docs * add kandinsky5 to toctree * add tests * fix tests * Apply style fixes * update tests * minor docs refactoring * refactor Kandinsky 5.0 Vide docs * Update docs/source/en/_toctree.yml --------- Co-authored-by: Álvaro Somoza Co-authored-by: YiYi Xu Co-authored-by: Charles Co-authored-by: Sayak Paul Co-authored-by: github-actions[bot] --- docs/source/en/_toctree.yml | 4 ++-- .../api/pipelines/{kandinsky5.md => kandinsky5_video.md} | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) rename docs/source/en/api/pipelines/{kandinsky5.md => kandinsky5_video.md} (90%) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 8103c01643..251eb25899 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -529,8 +529,6 @@ title: Kandinsky 2.2 - local: api/pipelines/kandinsky3 title: Kandinsky 3 - - local: api/pipelines/kandinsky5 - title: Kandinsky 5 - local: api/pipelines/kolors title: Kolors - local: api/pipelines/latent_consistency_models @@ -656,6 +654,8 @@ title: Text2Video-Zero - local: api/pipelines/wan title: Wan + - local: api/pipelines/kandinsky5_video + title: Kandinsky 5.0 Video title: Video title: Pipelines - sections: diff --git a/docs/source/en/api/pipelines/kandinsky5.md b/docs/source/en/api/pipelines/kandinsky5_video.md similarity index 90% rename from docs/source/en/api/pipelines/kandinsky5.md rename to docs/source/en/api/pipelines/kandinsky5_video.md index a98a0826b7..533db23e1c 100644 --- a/docs/source/en/api/pipelines/kandinsky5.md +++ b/docs/source/en/api/pipelines/kandinsky5_video.md @@ -7,9 +7,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Kandinsky 5.0 +# Kandinsky 5.0 Video -Kandinsky 5.0 is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov +Kandinsky 5.0 Video is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov Kandinsky 5.0 is a family of diffusion models for Video & Image generation. Kandinsky 5.0 T2V Lite is a lightweight video generation model (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem. @@ -92,7 +92,7 @@ pipe = pipe.to("cuda") pipe.transformer.set_attention_backend( "flex" -) # <--- Set attention backend to Flex +) # <--- Sett attention bakend to Flex pipe.transformer.compile( mode="max-autotune-no-cudagraphs", dynamic=True @@ -115,7 +115,7 @@ export_to_video(output, "output.mp4", fps=24, quality=9) ``` ### Diffusion Distilled model -**⚠️ Warning!** all nocfg and diffusion distilled models should be inferred without CFG (```guidance_scale=1.0```): +**⚠️ Warning!** all nocfg and diffusion distilled models should be infered wothout CFG (```guidance_scale=1.0```): ```python model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers" From ac5a1e28fc9cc233863bcfb2abb9eef6807f156f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 4 Nov 2025 10:26:07 +0530 Subject: [PATCH 4/5] [docs] sort doc (#12586) sort doc --- docs/source/en/_toctree.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 251eb25899..5af95cba74 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -636,6 +636,8 @@ title: HunyuanVideo - local: api/pipelines/i2vgenxl title: I2VGen-XL + - local: api/pipelines/kandinsky5_video + title: Kandinsky 5.0 Video - local: api/pipelines/latte title: Latte - local: api/pipelines/ltx_video @@ -654,8 +656,6 @@ title: Text2Video-Zero - local: api/pipelines/wan title: Wan - - local: api/pipelines/kandinsky5_video - title: Kandinsky 5.0 Video title: Video title: Pipelines - sections: From dcfb18a2d340d8e1f0ff001b06d2931ffa8648da Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Tue, 4 Nov 2025 14:27:25 +0200 Subject: [PATCH 5/5] [LoRA] add support for more Qwen LoRAs (#12581) * fix bug when offload and cache_latents both enabled * fix --- src/diffusers/loaders/lora_conversion_utils.py | 4 ++++ src/diffusers/loaders/lora_pipeline.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 099dbfc1d2..2807416f97 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -2213,6 +2213,10 @@ def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict): state_dict = {convert_key(k): v for k, v in state_dict.items()} + has_default = any("default." in k for k in state_dict) + if has_default: + state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()} + converted_state_dict = {} all_keys = list(state_dict.keys()) down_key = ".lora_down.weight" diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 2bb6c0ea02..25919a896a 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -4940,7 +4940,8 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin): has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict) has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict) has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict) - if has_alphas_in_sd or has_lora_unet or has_diffusion_model: + has_default = any("default." in k for k in state_dict) + if has_alphas_in_sd or has_lora_unet or has_diffusion_model or has_default: state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict) out = (state_dict, metadata) if return_lora_metadata else state_dict