1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

addressed PR comments

This commit is contained in:
ishan-modi
2025-03-20 19:14:22 +05:30
parent 7f3cbc595b
commit 4145f6b044
2 changed files with 14 additions and 9 deletions

View File

@@ -246,8 +246,8 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
# 2. Transformer blocks
block_res_samples = ()
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.transformer_blocks:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
@@ -258,7 +258,9 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
post_patch_height,
post_patch_width,
)
else:
block_res_samples = block_res_samples + (hidden_states,)
else:
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
attention_mask,
@@ -268,7 +270,7 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
post_patch_height,
post_patch_width,
)
block_res_samples = block_res_samples + (hidden_states,)
block_res_samples = block_res_samples + (hidden_states,)
# 3. ControlNet blocks
controlnet_block_res_samples = ()

View File

@@ -434,8 +434,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
encoder_hidden_states = self.caption_norm(encoder_hidden_states)
# 2. Transformer blocks
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
if torch.is_grad_enabled() and self.gradient_checkpointing:
for index_block, block in enumerate(self.transformer_blocks):
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
@@ -446,8 +446,11 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
post_patch_height,
post_patch_width,
)
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
else:
else:
for index_block, block in enumerate(self.transformer_blocks):
hidden_states = block(
hidden_states,
attention_mask,
@@ -457,8 +460,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
post_patch_height,
post_patch_width,
)
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
# 3. Normalization
hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)