Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-29 07:22:12 +03:00
[Core] refactor transformer_2d forward logic into meaningful conditions. (#7489)
* refactor transformer_2d forward logic into meaningful conditions.
* Empty-Commit
* fix: _operate_on_patched_inputs
* fix: _operate_on_patched_inputs
* check
* fix: patch output computation block.
* fix: _operate_on_patched_inputs.
* remove print.
* move operations to blocks.
* more readability tweaks.
* empty commit
* Apply suggestions from code review
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
* Revert "Apply suggestions from code review"
This reverts commit 12178b1aa0.
---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
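The refactor leaves `forward` as a thin dispatcher: each input mode (`is_input_continuous`, `is_input_vectorized`, `is_input_patches`) hands its pre- and post-processing to a dedicated private helper, as shown in the diff below. As a quick sanity check of the continuous path, here is a minimal smoke-test sketch; it assumes a diffusers build that contains this commit, and the toy hyperparameters are illustrative, not taken from the PR:

```python
import torch

from diffusers import Transformer2DModel

# Continuous-input configuration: `in_channels` is set and `patch_size` is not,
# so forward routes through `_operate_on_continuous_inputs` and
# `_get_output_for_continuous_inputs` introduced by this refactor.
model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,  # inner_dim = 2 * 8 = 16
    in_channels=4,
    norm_num_groups=2,     # must divide in_channels for the GroupNorm
    num_layers=1,
).eval()

latents = torch.randn(1, 4, 8, 8)  # (batch, channels, height, width)
with torch.no_grad():
    sample = model(latents).sample

# The continuous output path adds the residual back, so the shape is preserved.
print(sample.shape)  # torch.Size([1, 4, 8, 8])
```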
@@ -402,41 +402,18 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         # 1. Input
         if self.is_input_continuous:
-            batch, _, height, width = hidden_states.shape
+            batch_size, _, height, width = hidden_states.shape
             residual = hidden_states
-
-            hidden_states = self.norm(hidden_states)
-            if not self.use_linear_projection:
-                hidden_states = self.proj_in(hidden_states)
-                inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-            else:
-                inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-                hidden_states = self.proj_in(hidden_states)
-
+            hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states)
         elif self.is_input_vectorized:
             hidden_states = self.latent_image_embedding(hidden_states)
         elif self.is_input_patches:
             height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
-            hidden_states = self.pos_embed(hidden_states)
-
-            if self.adaln_single is not None:
-                if self.use_additional_conditions and added_cond_kwargs is None:
-                    raise ValueError(
-                        "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
-                    )
-                batch_size = hidden_states.shape[0]
-                timestep, embedded_timestep = self.adaln_single(
-                    timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
-                )
+            hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
+                hidden_states, encoder_hidden_states, timestep, added_cond_kwargs
+            )

         # 2. Blocks
-        if self.is_input_patches and self.caption_projection is not None:
-            batch_size = hidden_states.shape[0]
-            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
-            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-
         for block in self.transformer_blocks:
             if self.training and self.gradient_checkpointing:
@@ -474,51 +451,116 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         # 3. Output
         if self.is_input_continuous:
-            if not self.use_linear_projection:
-                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-                hidden_states = self.proj_out(hidden_states)
-            else:
-                hidden_states = self.proj_out(hidden_states)
-                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-
-            output = hidden_states + residual
-        elif self.is_input_vectorized:
-            hidden_states = self.norm_out(hidden_states)
-            logits = self.out(hidden_states)
-            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
-            logits = logits.permute(0, 2, 1)
-
-            # log(p(x_0))
-            output = F.log_softmax(logits.double(), dim=1).float()
-
-        if self.is_input_patches:
-            if self.config.norm_type != "ada_norm_single":
-                conditioning = self.transformer_blocks[0].norm1.emb(
-                    timestep, class_labels, hidden_dtype=hidden_states.dtype
-                )
-                shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
-                hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
-                hidden_states = self.proj_out_2(hidden_states)
-            elif self.config.norm_type == "ada_norm_single":
-                shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
-                hidden_states = self.norm_out(hidden_states)
-                # Modulation
-                hidden_states = hidden_states * (1 + scale) + shift
-                hidden_states = self.proj_out(hidden_states)
-                hidden_states = hidden_states.squeeze(1)
-
-            # unpatchify
-            if self.adaln_single is None:
-                height = width = int(hidden_states.shape[1] ** 0.5)
-            hidden_states = hidden_states.reshape(
-                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
-            )
-            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
-            output = hidden_states.reshape(
-                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
-            )
+            output = self._get_output_for_continuous_inputs(
+                hidden_states=hidden_states,
+                residual=residual,
+                batch_size=batch_size,
+                height=height,
+                width=width,
+                inner_dim=inner_dim,
+            )
+        elif self.is_input_vectorized:
+            output = self._get_output_for_vectorized_inputs(hidden_states)
+        elif self.is_input_patches:
+            output = self._get_output_for_patched_inputs(
+                hidden_states=hidden_states,
+                timestep=timestep,
+                class_labels=class_labels,
+                embedded_timestep=embedded_timestep,
+                height=height,
+                width=width,
+            )

         if not return_dict:
             return (output,)

         return Transformer2DModelOutput(sample=output)
+
+    def _operate_on_continuous_inputs(self, hidden_states):
+        batch, _, height, width = hidden_states.shape
+        hidden_states = self.norm(hidden_states)
+
+        if not self.use_linear_projection:
+            hidden_states = self.proj_in(hidden_states)
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+        else:
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+            hidden_states = self.proj_in(hidden_states)
+
+        return hidden_states, inner_dim
+
+    def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs):
+        batch_size = hidden_states.shape[0]
+        hidden_states = self.pos_embed(hidden_states)
+        embedded_timestep = None
+
+        if self.adaln_single is not None:
+            if self.use_additional_conditions and added_cond_kwargs is None:
+                raise ValueError(
+                    "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
+                )
+            timestep, embedded_timestep = self.adaln_single(
+                timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+            )
+
+        if self.caption_projection is not None:
+            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+        return hidden_states, encoder_hidden_states, timestep, embedded_timestep
+
+    def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim):
+        if not self.use_linear_projection:
+            hidden_states = (
+                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+            )
+            hidden_states = self.proj_out(hidden_states)
+        else:
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = (
+                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+            )
+
+        output = hidden_states + residual
+        return output
+
+    def _get_output_for_vectorized_inputs(self, hidden_states):
+        hidden_states = self.norm_out(hidden_states)
+        logits = self.out(hidden_states)
+        # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
+        logits = logits.permute(0, 2, 1)
+        # log(p(x_0))
+        output = F.log_softmax(logits.double(), dim=1).float()
+        return output
+
+    def _get_output_for_patched_inputs(
+        self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None
+    ):
+        if self.config.norm_type != "ada_norm_single":
+            conditioning = self.transformer_blocks[0].norm1.emb(
+                timestep, class_labels, hidden_dtype=hidden_states.dtype
+            )
+            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+            hidden_states = self.proj_out_2(hidden_states)
+        elif self.config.norm_type == "ada_norm_single":
+            shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+            hidden_states = self.norm_out(hidden_states)
+            # Modulation
+            hidden_states = hidden_states * (1 + scale) + shift
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = hidden_states.squeeze(1)
+
+        # unpatchify
+        if self.adaln_single is None:
+            height = width = int(hidden_states.shape[1] ** 0.5)
+        hidden_states = hidden_states.reshape(
+            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+        )
+        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+        output = hidden_states.reshape(
+            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+        )
+        return output
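For the patched path, a DiT-style configuration exercises `_operate_on_patched_inputs` and `_get_output_for_patched_inputs` from the diff above. Again a hedged sketch rather than code from the PR: the hyperparameters are made up for illustration, and it assumes the `ada_norm_zero` branch, which requires `timestep` and `class_labels`:

```python
import torch

from diffusers import Transformer2DModel

# Patched-input (DiT-style) configuration: `patch_size` and `sample_size` set,
# class-conditioned adaptive layer norm.
model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    sample_size=8,
    patch_size=2,
    norm_type="ada_norm_zero",
    num_embeds_ada_norm=10,  # number of class labels
    num_layers=1,
).eval()

latents = torch.randn(1, 4, 8, 8)
timestep = torch.tensor([1])
class_labels = torch.tensor([0])

with torch.no_grad():
    # _operate_on_patched_inputs embeds the patches; _get_output_for_patched_inputs
    # applies the adaptive-norm modulation and un-patchifies back to (B, C, H, W).
    sample = model(latents, timestep=timestep, class_labels=class_labels).sample

print(sample.shape)  # torch.Size([1, 4, 8, 8])
```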