Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Fix a bug for SD35 control net training and improve control net block index (#10065)
* wip

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
@@ -393,13 +393,19 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
                     return custom_forward
 
                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                if self.context_embedder is not None:
+                    encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block),
+                        hidden_states,
+                        encoder_hidden_states,
+                        temb,
+                        **ckpt_kwargs,
+                    )
+                else:
+                    # SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states`
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block), hidden_states, temb, **ckpt_kwargs
+                    )
 
             else:
                 if self.context_embedder is not None:
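Why this hunk branches: controlnet training runs the gradient-checkpointing path, and the old code always unpacked two tensors from the checkpointed call. The single-stream blocks used by the SD3.5-Large (8b) controlnet return only `hidden_states`, so that unpacking either raises a ValueError or, depending on the batch size, silently splits the tensor along its first dimension. The fix mirrors the `self.context_embedder is not None` check that the non-checkpointed path already performs. Below is a minimal, self-contained sketch of that control flow, not the diffusers implementation; `ToyJointBlock`, `ToySingleBlock`, and `run_checkpointed` are hypothetical stand-ins that only reproduce the two call/return shapes around `torch.utils.checkpoint.checkpoint`.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyJointBlock(nn.Module):
    """Stand-in for a dual-stream block: returns (encoder_hidden_states, hidden_states)."""

    def __init__(self, dim: int):
        super().__init__()
        self.txt_proj = nn.Linear(dim, dim)
        self.img_proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, encoder_hidden_states, temb):
        return self.txt_proj(encoder_hidden_states) + temb, self.img_proj(hidden_states) + temb


class ToySingleBlock(nn.Module):
    """Stand-in for a single-stream block: returns hidden_states only."""

    def __init__(self, dim: int):
        super().__init__()
        self.img_proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, temb):
        return self.img_proj(hidden_states) + temb


def run_checkpointed(block, hidden_states, encoder_hidden_states, temb, has_context: bool):
    # Same shape as the fix: decide how many outputs to unpack *before* the
    # checkpointed call, because the single-stream block returns one tensor.
    if has_context:
        encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
            block, hidden_states, encoder_hidden_states, temb, use_reentrant=False
        )
    else:
        hidden_states = torch.utils.checkpoint.checkpoint(block, hidden_states, temb, use_reentrant=False)
    return encoder_hidden_states, hidden_states


dim = 8
hidden = torch.randn(2, 4, dim, requires_grad=True)
context = torch.randn(2, 3, dim, requires_grad=True)
temb = torch.randn(2, 1, dim)

run_checkpointed(ToyJointBlock(dim), hidden, context, temb, has_context=True)    # unpacks two tensors
run_checkpointed(ToySingleBlock(dim), hidden, context, temb, has_context=False)  # unpacks one tensor
```

Flipping `has_context` between the two toy blocks reproduces the difference the hunk guards against: two outputs for the joint block, one for the single-stream block.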
@@ -15,7 +15,6 @@
 
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -424,8 +423,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
             # controlnet residual
             if block_controlnet_hidden_states is not None and block.context_pre_only is False:
                 interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
-                interval_control = int(np.ceil(interval_control))
-                hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
+                hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]
 
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
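The second part of the change ("improve control net block index") replaces the ceil-ed integer interval plus floor division with a float interval plus truncation when routing controlnet residuals to transformer blocks. The two mappings agree whenever the transformer depth is an exact multiple of the controlnet depth; when it is not, the old mapping can leave the last controlnet outputs unused. A small sketch with made-up block counts (24 transformer blocks, 10 controlnet residuals), chosen purely for illustration:

```python
import math

# Hypothetical counts, chosen only to show the behaviour when the
# number of transformer blocks is not a multiple of the controlnet depth.
num_transformer_blocks = 24
num_controlnet_blocks = 10

# Old mapping: ceil the interval to an int, then floor-divide the block index.
old_interval = int(math.ceil(num_transformer_blocks / num_controlnet_blocks))  # 3
old_indices = [i // old_interval for i in range(num_transformer_blocks)]

# New mapping: keep the interval as a float, truncate the scaled index.
new_interval = num_transformer_blocks / num_controlnet_blocks  # 2.4
new_indices = [int(i / new_interval) for i in range(num_transformer_blocks)]

print(sorted(set(old_indices)))  # [0, 1, 2, 3, 4, 5, 6, 7] -> last two controlnet outputs never used
print(sorted(set(new_indices)))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] -> every controlnet output used
```

With these counts the old mapping only ever indexes controlnet outputs 0-7, while the new one reaches all ten, which is the improved block index behaviour the commit title refers to.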