diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index e6a09e0d98..599c90be57 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -344,7 +344,7 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: return config, RENAME_DICT, SPECIAL_KEYS_REMAP -def convert_transformer(model_type: str, stage: str=None): +def convert_transformer(model_type: str, stage: str = None): config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config(model_type) diffusers_config = config["diffusers_config"] @@ -580,115 +580,116 @@ def convert_vae(): vae.load_state_dict(new_state_dict, strict=True, assign=True) return vae + vae22_diffusers_config = { - "base_dim": 160, - "z_dim": 48, - "is_residual": True, - "in_channels": 12, - "out_channels": 12, - "decoder_base_dim": 256, - "scale_factor_temporal": 4, - "scale_factor_spatial": 16, - "patch_size": 2, - "latents_mean":[ - -0.2289, - -0.0052, - -0.1323, - -0.2339, - -0.2799, - 0.0174, - 0.1838, - 0.1557, - -0.1382, - 0.0542, - 0.2813, - 0.0891, - 0.1570, - -0.0098, - 0.0375, - -0.1825, - -0.2246, - -0.1207, - -0.0698, - 0.5109, - 0.2665, - -0.2108, - -0.2158, - 0.2502, - -0.2055, - -0.0322, - 0.1109, - 0.1567, - -0.0729, - 0.0899, - -0.2799, - -0.1230, - -0.0313, - -0.1649, - 0.0117, - 0.0723, - -0.2839, - -0.2083, - -0.0520, - 0.3748, - 0.0152, - 0.1957, - 0.1433, - -0.2944, - 0.3573, - -0.0548, - -0.1681, - -0.0667, + "base_dim": 160, + "z_dim": 48, + "is_residual": True, + "in_channels": 12, + "out_channels": 12, + "decoder_base_dim": 256, + "scale_factor_temporal": 4, + "scale_factor_spatial": 16, + "patch_size": 2, + "latents_mean": [ + -0.2289, + -0.0052, + -0.1323, + -0.2339, + -0.2799, + 0.0174, + 0.1838, + 0.1557, + -0.1382, + 0.0542, + 0.2813, + 0.0891, + 0.1570, + -0.0098, + 0.0375, + -0.1825, + -0.2246, + -0.1207, + -0.0698, + 0.5109, + 0.2665, + -0.2108, + -0.2158, + 0.2502, + -0.2055, + -0.0322, + 0.1109, + 0.1567, + -0.0729, + 0.0899, + -0.2799, + -0.1230, + -0.0313, + -0.1649, + 0.0117, + 0.0723, + -0.2839, + -0.2083, + -0.0520, + 0.3748, + 0.0152, + 0.1957, + 0.1433, + -0.2944, + 0.3573, + -0.0548, + -0.1681, + -0.0667, ], - "latents_std": [ - 0.4765, - 1.0364, - 0.4514, - 1.1677, - 0.5313, - 0.4990, - 0.4818, - 0.5013, - 0.8158, - 1.0344, - 0.5894, - 1.0901, - 0.6885, - 0.6165, - 0.8454, - 0.4978, - 0.5759, - 0.3523, - 0.7135, - 0.6804, - 0.5833, - 1.4146, - 0.8986, - 0.5659, - 0.7069, - 0.5338, - 0.4889, - 0.4917, - 0.4069, - 0.4999, - 0.6866, - 0.4093, - 0.5709, - 0.6065, - 0.6415, - 0.4944, - 0.5726, - 1.2042, - 0.5458, - 1.6887, - 0.3971, - 1.0600, - 0.3943, - 0.5537, - 0.5444, - 0.4089, - 0.7468, - 0.7744, + "latents_std": [ + 0.4765, + 1.0364, + 0.4514, + 1.1677, + 0.5313, + 0.4990, + 0.4818, + 0.5013, + 0.8158, + 1.0344, + 0.5894, + 1.0901, + 0.6885, + 0.6165, + 0.8454, + 0.4978, + 0.5759, + 0.3523, + 0.7135, + 0.6804, + 0.5833, + 1.4146, + 0.8986, + 0.5659, + 0.7069, + 0.5338, + 0.4889, + 0.4917, + 0.4069, + 0.4999, + 0.6866, + 0.4093, + 0.5709, + 0.6065, + 0.6415, + 0.4944, + 0.5726, + 1.2042, + 0.5458, + 1.6887, + 0.3971, + 1.0600, + 0.3943, + 0.5537, + 0.5444, + 0.4089, + 0.7468, + 0.7744, ], "clip_output": False, } diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index 5ff969c5ee..608de25da5 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -35,7 +35,6 @@ CACHE_T = 2 class 
AvgDown3D(nn.Module): - def __init__( self, in_channels, @@ -89,7 +88,6 @@ class AvgDown3D(nn.Module): class DupUp3D(nn.Module): - def __init__( self, in_channels: int, @@ -129,9 +127,10 @@ class DupUp3D(nn.Module): x.size(6) * self.factor_s, ) if first_chunk: - x = x[:, :, self.factor_t - 1:, :, :] + x = x[:, :, self.factor_t - 1 :, :, :] return x + class WanCausalConv3d(nn.Conv3d): r""" A custom 3D causal convolution layer with feature caching support. @@ -244,11 +243,13 @@ class WanResample(nn.Module): # layers if mode == "upsample2d": self.resample = nn.Sequential( - WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1) + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, upsample_out_dim, 3, padding=1), ) elif mode == "upsample3d": self.resample = nn.Sequential( - WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1) + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, upsample_out_dim, 3, padding=1), ) self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) @@ -466,14 +467,7 @@ class WanMidBlock(nn.Module): class WanResidualDownBlock(nn.Module): - - def __init__(self, - in_dim, - out_dim, - dropout, - num_res_blocks, - temperal_downsample=False, - down_flag=False): + def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample=False, down_flag=False): super().__init__() # Shortcut path with downsample @@ -507,6 +501,7 @@ class WanResidualDownBlock(nn.Module): return x + self.avg_shortcut(x_copy) + class WanEncoder3d(nn.Module): r""" A 3D encoder module. @@ -533,7 +528,7 @@ class WanEncoder3d(nn.Module): temperal_downsample=[True, True, False], dropout=0.0, non_linearity: str = "silu", - is_residual: bool = False, # wan 2.2 vae use a residual downblock + is_residual: bool = False, # wan 2.2 vae use a residual downblock ): super().__init__() self.dim = dim @@ -564,8 +559,8 @@ class WanEncoder3d(nn.Module): num_res_blocks, temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False, down_flag=i != len(dim_mult) - 1, - ) - ) + ) + ) else: for _ in range(num_res_blocks): self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout)) @@ -627,6 +622,7 @@ class WanEncoder3d(nn.Module): x = self.conv_out(x) return x + class WanResidualUpBlock(nn.Module): """ A block that handles upsampling for the WanVAE decoder. @@ -714,6 +710,7 @@ class WanResidualUpBlock(nn.Module): return x + class WanUpBlock(nn.Module): """ A block that handles upsampling for the WanVAE decoder. 
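A minimal sketch of how the per-channel `latents_mean` / `latents_std` statistics carried by `vae22_diffusers_config` above are typically applied around `AutoencoderKLWan.decode`, mirroring the denormalization used by the existing Wan pipelines; the checkpoint id below is a placeholder and the dummy latent shape is only illustrative:

```python
import torch

from diffusers import AutoencoderKLWan

# Placeholder repo id; substitute the actual converted Wan 2.2 VAE checkpoint.
vae = AutoencoderKLWan.from_pretrained("<converted-wan2.2-vae>", torch_dtype=torch.float32)

# Per-channel statistics stored in the model config (z_dim == 48 for the Wan 2.2 VAE).
latents_mean = torch.tensor(vae.config.latents_mean).view(1, vae.config.z_dim, 1, 1, 1)
latents_std = torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1)

# Dummy normalized latents: (batch, channels, frames, height, width).
latents = torch.randn(1, vae.config.z_dim, 1, 16, 16)

# Undo the normalization before decoding, as the Wan pipelines do.
video = vae.decode(latents * latents_std + latents_mean, return_dict=False)[0]
```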
@@ -854,7 +851,7 @@ class WanDecoder3d(nn.Module): num_res_blocks=num_res_blocks, dropout=dropout, temperal_upsample=temperal_upsample[i] if up_flag else False, - up_flag= up_flag, + up_flag=up_flag, non_linearity=non_linearity, ) else: @@ -893,7 +890,7 @@ class WanDecoder3d(nn.Module): ## upsamples for up_block in self.up_blocks: - x = up_block(x, feat_cache, feat_idx, first_chunk = first_chunk) + x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk) ## head x = self.norm_out(x) @@ -913,20 +910,39 @@ class WanDecoder3d(nn.Module): def patchify(x, patch_size): - # YiYi TODO: refactor this - from einops import rearrange if patch_size == 1: return x + if x.dim() == 4: - x = rearrange( - x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size) + # x shape: [batch_size, channels, height, width] + batch_size, channels, height, width = x.shape + + # Ensure height and width are divisible by patch_size + if height % patch_size != 0 or width % patch_size != 0: + raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})") + + # Reshape to [batch_size, channels, height//patch_size, patch_size, width//patch_size, patch_size] + x = x.view(batch_size, channels, height // patch_size, patch_size, width // patch_size, patch_size) + + # Rearrange to [batch_size, channels * patch_size * patch_size, height//patch_size, width//patch_size] + x = x.permute(0, 1, 3, 5, 2, 4).contiguous() + x = x.view(batch_size, channels * patch_size * patch_size, height // patch_size, width // patch_size) + elif x.dim() == 5: - x = rearrange( - x, - "b c f (h q) (w r) -> b (c r q) f h w", - q=patch_size, - r=patch_size, - ) + # x shape: [batch_size, channels, frames, height, width] + batch_size, channels, frames, height, width = x.shape + + # Ensure height and width are divisible by patch_size + if height % patch_size != 0 or width % patch_size != 0: + raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})") + + # Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size] + x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size) + + # Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size] + x = x.permute(0, 1, 4, 6, 2, 3, 5).contiguous() + x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size) + else: raise ValueError(f"Invalid input shape: {x.shape}") @@ -934,23 +950,36 @@ def patchify(x, patch_size): def unpatchify(x, patch_size): - # YiYi TODO: refactor this - from einops import rearrange if patch_size == 1: return x if x.dim() == 4: - x = rearrange( - x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size) + # x shape: [b, (c * patch_size * patch_size), h, w] + batch_size, c_patches, height, width = x.shape + channels = c_patches // (patch_size * patch_size) + + # Reshape to [b, c, patch_size, patch_size, h, w] + x = x.view(batch_size, channels, patch_size, patch_size, height, width) + + # Rearrange to [b, c, h * patch_size, w * patch_size] + x = x.permute(0, 1, 4, 2, 5, 3).contiguous() + x = x.view(batch_size, channels, height * patch_size, width * patch_size) + elif x.dim() == 5: - x = rearrange( - x, - "b (c r q) f h w -> b c f (h q) (w r)", - q=patch_size, - r=patch_size, - ) + # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width] + batch_size, c_patches, frames, 
height, width = x.shape + channels = c_patches // (patch_size * patch_size) + + # Reshape to [b, c, patch_size, patch_size, f, h, w] + x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width) + + # Rearrange to [b, c, f, h * patch_size, w * patch_size] + x = x.permute(0, 1, 4, 5, 2, 6, 3).contiguous() + x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size) + return x + class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. @@ -1027,13 +1056,29 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): decoder_base_dim = base_dim self.encoder = WanEncoder3d( - in_channels=in_channels, dim=base_dim, z_dim=z_dim * 2, dim_mult=dim_mult, num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_downsample=temperal_downsample, dropout=dropout, is_residual=is_residual + in_channels=in_channels, + dim=base_dim, + z_dim=z_dim * 2, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, + dropout=dropout, + is_residual=is_residual, ) self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1) self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1) self.decoder = WanDecoder3d( - dim=decoder_base_dim, z_dim=z_dim, dim_mult=dim_mult, num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_upsample=self.temperal_upsample, dropout=dropout, out_channels=out_channels, is_residual=is_residual + dim=decoder_base_dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_upsample=self.temperal_upsample, + dropout=dropout, + out_channels=out_channels, + is_residual=is_residual, ) self.spatial_compression_ratio = 2 ** len(self.temperal_downsample) @@ -1192,7 +1237,9 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): for i in range(num_frame): self._conv_idx = [0] if i == 0: - out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True) + out = self.decoder( + x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True + ) else: out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) out = torch.cat([out, out_], 2) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 7352a8a21a..eddf196718 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -312,7 +312,6 @@ class WanTransformerBlock(nn.Module): temb: torch.Tensor, rotary_emb: torch.Tensor, ) -> torch.Tensor: - if temb.ndim == 4: # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v) shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( @@ -490,7 +489,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v) if timestep.ndim == 2: ts_seq_len = timestep.shape[1] - timestep = timestep.flatten() # batch_size * seq_len + timestep = timestep.flatten() # batch_size * seq_len else: ts_seq_len = None @@ -518,7 +517,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) # 5. 
Output norm, projection & unpatchify - if temb.ndim ==3: + if temb.ndim == 3: # batch_size, seq_len, inner_dim (wan 2.2 ti2v) shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2) shift = shift.squeeze(2) diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index 65e4c1c344..f52bf33d81 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -113,21 +113,20 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin): vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. transformer_2 ([`WanTransformer3DModel`], *optional*): - Conditional Transformer to denoise the input latents during the low-noise stage. - If provided, enables two-stage denoising where `transformer` handles high-noise stages - and `transformer_2` handles low-noise stages. If not provided, only `transformer` is used. + Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables + two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise + stages. If not provided, only `transformer` is used. boundary_ratio (`float`, *optional*, defaults to `None`): Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising. - The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. - When provided, `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < boundary_timestep. - If `None`, only `transformer` is used for the entire denoising process. + The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided, + `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < + boundary_timestep. If `None`, only `transformer` is used for the entire denoising process. """ model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] _optional_components = ["transformer_2"] - def __init__( self, tokenizer: AutoTokenizer, @@ -137,7 +136,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin): scheduler: FlowMatchEulerDiscreteScheduler, transformer_2: Optional[WanTransformer3DModel] = None, boundary_ratio: Optional[float] = None, - expand_timesteps: bool = False, # Wan2.2 ti2v + expand_timesteps: bool = False, # Wan2.2 ti2v ): super().__init__() @@ -429,8 +428,9 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin): `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. guidance_scale_2 (`float`, *optional*, defaults to `None`): - Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's `boundary_ratio` is not None, - uses the same value as `guidance_scale`. Only used when `transformer_2` and the pipeline's `boundary_ratio` are not None. + Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's + `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2` + and the pipeline's `boundary_ratio` are not None. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. 
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -549,7 +549,6 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin): latents, ) - mask = torch.ones(latents.shape, dtype=torch.float32, device=device) # 6. Denoising loop diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index c2299a2e46..b075cf5ba0 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -150,14 +150,14 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. transformer_2 ([`WanTransformer3DModel`], *optional*): - Conditional Transformer to denoise the input latents during the low-noise stage. - In two-stage denoising, `transformer` handles high-noise stages - and `transformer_2` handles low-noise stages. If not provided, only `transformer` is used. + Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising, + `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only + `transformer` is used. boundary_ratio (`float`, *optional*, defaults to `None`): Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising. - The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. - When provided, `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < boundary_timestep. - If `None`, only `transformer` is used for the entire denoising process. + The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided, + `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < + boundary_timestep. If `None`, only `transformer` is used for the entire denoising process. """ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae" @@ -171,9 +171,9 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): transformer: WanTransformer3DModel, vae: AutoencoderKLWan, scheduler: FlowMatchEulerDiscreteScheduler, - image_processor: CLIPImageProcessor=None, - image_encoder: CLIPVisionModel=None, - transformer_2: WanTransformer3DModel=None, + image_processor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModel = None, + transformer_2: WanTransformer3DModel = None, boundary_ratio: Optional[float] = None, ): super().__init__() @@ -550,8 +550,9 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. guidance_scale_2 (`float`, *optional*, defaults to `None`): - Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's `boundary_ratio` is not None, - uses the same value as `guidance_scale`. Only used when `transformer_2` and the pipeline's `boundary_ratio` are not None. + Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's + `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2` + and the pipeline's `boundary_ratio` are not None. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. 
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -624,7 +625,6 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 num_frames = max(num_frames, 1) - if self.config.boundary_ratio is not None and guidance_scale_2 is None: guidance_scale_2 = guidance_scale @@ -662,7 +662,6 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) - if self.config.boundary_ratio is None: if image_embeds is None: if last_image is None: @@ -706,7 +705,6 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): else: boundary_timestep = None - with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -723,7 +721,6 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): current_model = self.transformer_2 current_guidance_scale = guidance_scale_2 - latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) timestep = t.expand(latents.shape[0])
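A condensed, self-contained sketch of the two-stage selection that the `boundary_ratio` docstrings above describe and that the denoising loop implements; the helper function and the dummy models are hypothetical, and classifier-free guidance itself is elided:

```python
def select_stage(t, transformer, transformer_2, boundary_ratio,
                 guidance_scale, guidance_scale_2=None, num_train_timesteps=1000):
    """Hypothetical helper: pick the model and guidance scale for timestep ``t``."""
    if guidance_scale_2 is None:
        # Matches the documented default: fall back to guidance_scale.
        guidance_scale_2 = guidance_scale
    if boundary_ratio is None or transformer_2 is None:
        # Single-stage denoising: only the primary transformer is used.
        return transformer, guidance_scale
    boundary_timestep = boundary_ratio * num_train_timesteps
    if t >= boundary_timestep:
        return transformer, guidance_scale      # high-noise stage
    return transformer_2, guidance_scale_2      # low-noise stage


# Dummy stand-ins for the two transformers, just to exercise the switch.
high_noise_model, low_noise_model = object(), object()
assert select_stage(900, high_noise_model, low_noise_model, 0.875, 5.0)[0] is high_noise_model
assert select_stage(100, high_noise_model, low_noise_model, 0.875, 5.0, 3.0)[0] is low_noise_model
```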