From 41b0c473d2c8da7eef17abf3a1290878ec509f1b Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 2 Apr 2025 01:20:53 +0200 Subject: [PATCH] fix controlnet flux --- src/diffusers/models/controlnets/controlnet_flux.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 51c34b7fe9..04ab72e82a 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -343,25 +343,25 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ) block_samples = block_samples + (hidden_states,) - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - single_block_samples = () for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func( + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, hidden_states, + encoder_hidden_states, temb, image_rotary_emb, ) else: - hidden_states = block( + encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, ) - single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) + single_block_samples = single_block_samples + (hidden_states,) # controlnet block controlnet_block_samples = ()