Mirror of https://github.com/huggingface/diffusers.git
Hidream refactoring follow ups (#11299)
* HiDream Image
* update
* -einops
* py3.8
* fix -einops
* mixins, offload_seq, option_components
* docs
* Apply style fixes
* trigger tests
* Apply suggestions from code review

  Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* joint_attention_kwargs -> attention_kwargs, fixes
* fast tests
* -_init_weights
* style tests
* move reshape logic
* update slice 😴
* supports_dduf
* 🤷🏻♂️
* Update src/diffusers/models/transformers/transformer_hidream_image.py

  Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* address review comments
* update tests
* doc updates
* update
* Update src/diffusers/models/transformers/transformer_hidream_image.py
* Apply style fixes

---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
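The recurring change in the diff below is dropping `self.config.<arg>` lookups (and the stored `self.llama_layers` attribute) from `__init__` in favour of the raw constructor arguments, while `forward` switches to reading `self.config.llama_layers`. This works because diffusers' `ConfigMixin` records every `__init__` argument through the `@register_to_config` decorator. The snippet below is a minimal sketch of that pattern with a toy module; it is not code from this commit, and the class name, layer shapes, and argument defaults are made up for illustration.

# Minimal sketch (toy module, not from this commit) of the diffusers config
# pattern the refactor relies on: @register_to_config records the __init__
# arguments, so forward() can read them later via self.config.<name>.
import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class ToyTransformer(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, num_attention_heads: int = 8, attention_head_dim: int = 64, llama_layers: tuple = (0, 1)):
        super().__init__()
        # Inside __init__ the raw arguments are used directly ...
        self.inner_dim = num_attention_heads * attention_head_dim
        self.proj = nn.Linear(self.inner_dim, self.inner_dim)

    def forward(self, hidden_states: torch.Tensor):
        # ... while forward() reads the registered config instead of an extra
        # stored attribute such as self.llama_layers.
        selected = list(self.config.llama_layers)
        return self.proj(hidden_states), selected


model = ToyTransformer()
hidden = torch.randn(1, 16, model.inner_dim)
out, layers = model(hidden)
print(out.shape, layers)  # torch.Size([1, 16, 512]) [0, 1]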
src/diffusers/models/transformers/transformer_hidream_image.py

@@ -604,8 +604,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
-        self.llama_layers = llama_layers
+        self.inner_dim = num_attention_heads * attention_head_dim

         self.t_embedder = HiDreamImageTimestepEmbed(self.inner_dim)
         self.p_embedder = HiDreamImagePooledEmbed(text_emb_dim, self.inner_dim)
@@ -621,13 +620,13 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                 HiDreamBlock(
                     HiDreamImageTransformerBlock(
                         dim=self.inner_dim,
-                        num_attention_heads=self.config.num_attention_heads,
-                        attention_head_dim=self.config.attention_head_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=attention_head_dim,
                         num_routed_experts=num_routed_experts,
                         num_activated_experts=num_activated_experts,
                     )
                 )
-                for _ in range(self.config.num_layers)
+                for _ in range(num_layers)
             ]
         )

@@ -636,42 +635,26 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                 HiDreamBlock(
                     HiDreamImageSingleTransformerBlock(
                         dim=self.inner_dim,
-                        num_attention_heads=self.config.num_attention_heads,
-                        attention_head_dim=self.config.attention_head_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=attention_head_dim,
                         num_routed_experts=num_routed_experts,
                         num_activated_experts=num_activated_experts,
                     )
                 )
-                for _ in range(self.config.num_single_layers)
+                for _ in range(num_single_layers)
             ]
         )

         self.final_layer = HiDreamImageOutEmbed(self.inner_dim, patch_size, self.out_channels)

-        caption_channels = [
-            caption_channels[1],
-        ] * (num_layers + num_single_layers) + [
-            caption_channels[0],
-        ]
+        caption_channels = [caption_channels[1]] * (num_layers + num_single_layers) + [caption_channels[0]]
         caption_projection = []
         for caption_channel in caption_channels:
             caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim))
         self.caption_projection = nn.ModuleList(caption_projection)
         self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)

-    def expand_timesteps(self, timesteps, batch_size, device):
-        if not torch.is_tensor(timesteps):
-            is_mps = device.type == "mps"
-            if isinstance(timesteps, float):
-                dtype = torch.float32 if is_mps else torch.float64
-            else:
-                dtype = torch.int32 if is_mps else torch.int64
-            timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
-        elif len(timesteps.shape) == 0:
-            timesteps = timesteps[None].to(device)
-        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-        timesteps = timesteps.expand(batch_size)
-        return timesteps
+        self.gradient_checkpointing = False

     def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
         if is_training:
@@ -773,7 +756,6 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         hidden_states = out

         # 0. time
-        timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
         timesteps = self.t_embedder(timesteps, hidden_states_type)
         p_embedder = self.p_embedder(pooled_embeds)
         temb = timesteps + p_embedder
@@ -793,7 +775,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):

         T5_encoder_hidden_states = encoder_hidden_states[0]
         encoder_hidden_states = encoder_hidden_states[-1]
-        encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
+        encoder_hidden_states = [encoder_hidden_states[k] for k in self.config.llama_layers]

         if self.caption_projection is not None:
             new_encoder_hidden_states = []
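The `caption_channels` change in the third hunk is purely cosmetic: the old multi-line list expression is collapsed into a single line. A small standalone check of that equivalence, using made-up layer counts and channel sizes for illustration only:

# Standalone check (with illustrative values) that the one-line form of the
# caption_channels expression builds the same list as the old multi-line form.
num_layers, num_single_layers = 2, 3
caption_channels = [4096, 2048]  # illustrative channel sizes, not real config values

old_form = [
    caption_channels[1],
] * (num_layers + num_single_layers) + [
    caption_channels[0],
]
new_form = [caption_channels[1]] * (num_layers + num_single_layers) + [caption_channels[0]]

assert old_form == new_form == [2048, 2048, 2048, 2048, 2048, 4096]
print(new_form)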