From 4145f6b0449e7d6234815343a32f96e11bc94a8d Mon Sep 17 00:00:00 2001
From: ishan-modi
Date: Thu, 20 Mar 2025 19:14:22 +0530
Subject: [PATCH] addressed PR comments

---
 src/diffusers/models/controlnets/controlnet_sana.py | 10 ++++++----
 .../models/transformers/sana_transformer.py         | 13 ++++++++-----
 2 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/src/diffusers/models/controlnets/controlnet_sana.py b/src/diffusers/models/controlnets/controlnet_sana.py
index c4b329cd01..7f9d6d9849 100644
--- a/src/diffusers/models/controlnets/controlnet_sana.py
+++ b/src/diffusers/models/controlnets/controlnet_sana.py
@@ -246,8 +246,8 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
 
         # 2. Transformer blocks
         block_res_samples = ()
-        for block in self.transformer_blocks:
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for block in self.transformer_blocks:
                 hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
@@ -258,7 +258,9 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     post_patch_height,
                     post_patch_width,
                 )
-            else:
+                block_res_samples = block_res_samples + (hidden_states,)
+        else:
+            for block in self.transformer_blocks:
                 hidden_states = block(
                     hidden_states,
                     attention_mask,
@@ -268,7 +270,7 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     post_patch_height,
                     post_patch_width,
                 )
-            block_res_samples = block_res_samples + (hidden_states,)
+                block_res_samples = block_res_samples + (hidden_states,)
 
         # 3. ControlNet blocks
         controlnet_block_res_samples = ()
diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py
index 4578d70327..a4c3da9de8 100644
--- a/src/diffusers/models/transformers/sana_transformer.py
+++ b/src/diffusers/models/transformers/sana_transformer.py
@@ -434,8 +434,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             encoder_hidden_states = self.caption_norm(encoder_hidden_states)
 
         # 2. Transformer blocks
-        for index_block, block in enumerate(self.transformer_blocks):
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for index_block, block in enumerate(self.transformer_blocks):
                 hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
@@ -446,8 +446,11 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     post_patch_height,
                     post_patch_width,
                 )
+                if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
+                    hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
 
-            else:
+        else:
+            for index_block, block in enumerate(self.transformer_blocks):
                 hidden_states = block(
                     hidden_states,
                     attention_mask,
@@ -457,8 +460,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     post_patch_height,
                     post_patch_width,
                 )
-            if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
-                hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
+                if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
+                    hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
 
         # 3. Normalization
         hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)