mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
addressed PR comments
This commit is contained in:
@@ -246,8 +246,8 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
|
||||
# 2. Transformer blocks
|
||||
block_res_samples = ()
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
@@ -258,7 +258,9 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
post_patch_height,
|
||||
post_patch_width,
|
||||
)
|
||||
else:
|
||||
block_res_samples = block_res_samples + (hidden_states,)
|
||||
else:
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
@@ -268,7 +270,7 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
post_patch_height,
|
||||
post_patch_width,
|
||||
)
|
||||
block_res_samples = block_res_samples + (hidden_states,)
|
||||
block_res_samples = block_res_samples + (hidden_states,)
|
||||
|
||||
# 3. ControlNet blocks
|
||||
controlnet_block_res_samples = ()
|
||||
|
||||
@@ -434,8 +434,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
encoder_hidden_states = self.caption_norm(encoder_hidden_states)
|
||||
|
||||
# 2. Transformer blocks
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
@@ -446,8 +446,11 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
post_patch_height,
|
||||
post_patch_width,
|
||||
)
|
||||
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
|
||||
|
||||
else:
|
||||
else:
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
@@ -457,8 +460,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
post_patch_height,
|
||||
post_patch_width,
|
||||
)
|
||||
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
|
||||
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
|
||||
|
||||
# 3. Normalization
|
||||
hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
|
||||
|
||||
Reference in New Issue
Block a user