diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index 49cefcd8a1..8ecde415dc 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -34,6 +34,104 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 CACHE_T = 2
 
+
+class AvgDown3D(nn.Module):
+    """Pixel-unshuffle style downsampling: folds each factor_t x factor_s x factor_s
+    neighborhood into the channel dimension, then averages groups of channels down to
+    `out_channels`. Serves as the parameter-free shortcut path of the Wan 2.2 VAE down blocks."""
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        factor_t,
+        factor_s=1,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.factor_t = factor_t
+        self.factor_s = factor_s
+        self.factor = self.factor_t * self.factor_s * self.factor_s
+
+        assert in_channels * self.factor % out_channels == 0
+        self.group_size = in_channels * self.factor // out_channels
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # pad the front of the time axis so T is divisible by factor_t
+        pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
+        pad = (0, 0, 0, 0, pad_t, 0)
+        x = F.pad(x, pad)
+        B, C, T, H, W = x.shape
+        x = x.view(
+            B,
+            C,
+            T // self.factor_t,
+            self.factor_t,
+            H // self.factor_s,
+            self.factor_s,
+            W // self.factor_s,
+            self.factor_s,
+        )
+        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
+        x = x.view(
+            B,
+            C * self.factor,
+            T // self.factor_t,
+            H // self.factor_s,
+            W // self.factor_s,
+        )
+        x = x.view(
+            B,
+            self.out_channels,
+            self.group_size,
+            T // self.factor_t,
+            H // self.factor_s,
+            W // self.factor_s,
+        )
+        x = x.mean(dim=2)
+        return x
+
+
+class DupUp3D(nn.Module):
+    """Shape-inverse of `AvgDown3D`: repeats channels, then unfolds them into a
+    factor_t x factor_s x factor_s neighborhood (nearest-neighbor duplication)."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        factor_t,
+        factor_s=1,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        self.factor_t = factor_t
+        self.factor_s = factor_s
+        self.factor = self.factor_t * self.factor_s * self.factor_s
+
+        assert out_channels * self.factor % in_channels == 0
+        self.repeats = out_channels * self.factor // in_channels
+
+    def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
+        x = x.repeat_interleave(self.repeats, dim=1)
+        x = x.view(
+            x.size(0),
+            self.out_channels,
+            self.factor_t,
+            self.factor_s,
+            self.factor_s,
+            x.size(2),
+            x.size(3),
+            x.size(4),
+        )
+        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
+        x = x.view(
+            x.size(0),
+            self.out_channels,
+            x.size(2) * self.factor_t,
+            x.size(4) * self.factor_s,
+            x.size(6) * self.factor_s,
+        )
+        if first_chunk:
+            # the first chunk maps to a single causal frame, so drop the duplicated lead frames
+            x = x[:, :, self.factor_t - 1 :, :, :]
+        return x
+
+
 class WanCausalConv3d(nn.Conv3d):
     r"""
     A custom 3D causal convolution layer with feature caching support.
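A quick shape check for the two helpers above (a minimal sketch; it assumes this patch is applied so the classes are importable from `diffusers.models.autoencoders.autoencoder_kl_wan`):

```python
import torch
import torch.nn.functional as F
from diffusers.models.autoencoders.autoencoder_kl_wan import AvgDown3D, DupUp3D

down = AvgDown3D(in_channels=96, out_channels=192, factor_t=2, factor_s=2)
up = DupUp3D(in_channels=192, out_channels=96, factor_t=2, factor_s=2)

x = torch.randn(1, 96, 8, 64, 64)  # (B, C, T, H, W)
y = down(x)
assert y.shape == (1, 192, 4, 32, 32)  # channels doubled, T/H/W halved
assert up(y).shape == (1, 96, 8, 64, 64)  # shape-inverse of the downsample
assert up(y, first_chunk=True).shape == (1, 96, 7, 64, 64)  # duplicated lead frame dropped

# With in_channels == out_channels, DupUp3D is exactly nearest-neighbor duplication.
z = torch.randn(2, 8, 3, 5, 5)
same = DupUp3D(8, 8, factor_t=2, factor_s=2)
assert torch.equal(same(z), F.interpolate(z, scale_factor=(2, 2, 2), mode="nearest"))
```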
""" - def __init__(self, dim: int, mode: str) -> None: + def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None: super().__init__() self.dim = dim self.mode = mode + + # default to dim //2 + if upsample_out_dim is None: + upsample_out_dim = dim // 2 # layers if mode == "upsample2d": self.resample = nn.Sequential( - WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1) ) elif mode == "upsample3d": self.resample = nn.Sequential( - WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1) ) self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) @@ -363,6 +465,48 @@ class WanMidBlock(nn.Module): return x +class WanResidualDownBlock(nn.Module): + + def __init__(self, + in_dim, + out_dim, + dropout, + num_res_blocks, + temperal_downsample=False, + down_flag=False): + super().__init__() + + # Shortcut path with downsample + self.avg_shortcut = AvgDown3D( + in_dim, + out_dim, + factor_t=2 if temperal_downsample else 1, + factor_s=2 if down_flag else 1, + ) + + # Main path with residual blocks and downsample + resnets = [] + for _ in range(num_res_blocks): + resnets.append(WanResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + self.resnets = nn.ModuleList(resnets) + + # Add the final downsample block + if down_flag: + mode = "downsample3d" if temperal_downsample else "downsample2d" + self.downsampler = WanResample(out_dim, mode=mode) + else: + self.downsampler = None + + def forward(self, x, feat_cache=None, feat_idx=[0]): + x_copy = x.clone() + for resnet in self.resnets: + x = resnet(x, feat_cache, feat_idx) + if self.downsampler is not None: + x = self.downsampler(x, feat_cache, feat_idx) + + return x + self.avg_shortcut(x_copy) + class WanEncoder3d(nn.Module): r""" A 3D encoder module. 
@@ -380,6 +524,7 @@ class WanEncoder3d(nn.Module):
 
     def __init__(
         self,
+        in_channels: int = 3,
         dim=128,
         z_dim=4,
         dim_mult=[1, 2, 4, 4],
@@ -388,6 +533,7 @@ class WanEncoder3d(nn.Module):
         temperal_downsample=[True, True, False],
         dropout=0.0,
         non_linearity: str = "silu",
+        is_residual: bool = False,  # the wan 2.2 vae uses residual down blocks
     ):
         super().__init__()
         self.dim = dim
@@ -403,23 +549,35 @@ class WanEncoder3d(nn.Module):
         scale = 1.0
 
         # init block
-        self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
+        self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)
 
         # downsample blocks
         self.down_blocks = nn.ModuleList([])
         for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
             # residual (+attention) blocks
-            for _ in range(num_res_blocks):
-                self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
-                if scale in attn_scales:
-                    self.down_blocks.append(WanAttentionBlock(out_dim))
-                in_dim = out_dim
+            if is_residual:
+                self.down_blocks.append(
+                    WanResidualDownBlock(
+                        in_dim,
+                        out_dim,
+                        dropout,
+                        num_res_blocks,
+                        temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False,
+                        down_flag=i != len(dim_mult) - 1,
+                    )
+                )
+            else:
+                for _ in range(num_res_blocks):
+                    self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
+                    if scale in attn_scales:
+                        self.down_blocks.append(WanAttentionBlock(out_dim))
+                    in_dim = out_dim
 
-            # downsample block
-            if i != len(dim_mult) - 1:
-                mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
-                self.down_blocks.append(WanResample(out_dim, mode=mode))
-                scale /= 2.0
+                # downsample block
+                if i != len(dim_mult) - 1:
+                    mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
+                    self.down_blocks.append(WanResample(out_dim, mode=mode))
+                    scale /= 2.0
 
         # middle blocks
         self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
@@ -469,6 +627,92 @@ class WanEncoder3d(nn.Module):
         x = self.conv_out(x)
         return x
 
+
+class WanResidualUpBlock(nn.Module):
+    """
+    A block that handles upsampling for the WanVAE decoder.
+
+    Args:
+        in_dim (int): Input dimension
+        out_dim (int): Output dimension
+        num_res_blocks (int): Number of residual blocks
+        dropout (float): Dropout rate
+        temperal_upsample (bool): Whether to upsample on the temporal dimension
+        up_flag (bool): Whether to upsample or not
+        non_linearity (str): Type of non-linearity to use
+    """
+
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        num_res_blocks: int,
+        dropout: float = 0.0,
+        temperal_upsample: bool = False,
+        up_flag: bool = False,
+        non_linearity: str = "silu",
+    ):
+        super().__init__()
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+
+        if up_flag:
+            self.avg_shortcut = DupUp3D(
+                in_dim,
+                out_dim,
+                factor_t=2 if temperal_upsample else 1,
+                factor_s=2,
+            )
+        else:
+            self.avg_shortcut = None
+
+        # create residual blocks
+        resnets = []
+        current_dim = in_dim
+        for _ in range(num_res_blocks + 1):
+            resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
+            current_dim = out_dim
+
+        self.resnets = nn.ModuleList(resnets)
+
+        # Add upsampling layer if needed
+        if up_flag:
+            upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
+            self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim)
+        else:
+            self.upsampler = None
+
+        self.gradient_checkpointing = False
+
+    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
+        """
+        Forward pass through the upsampling block.
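A small check of the new `upsample_out_dim` knob that `WanResidualUpBlock` relies on (a sketch, assuming the patched module): the wan 2.1 default halves the channel width, while the 2.2 residual path pins it to `out_dim`.

```python
from diffusers.models.autoencoders.autoencoder_kl_wan import WanResample

legacy = WanResample(dim=256, mode="upsample2d")
residual = WanResample(dim=256, mode="upsample2d", upsample_out_dim=256)
assert legacy.resample[1].out_channels == 128  # dim // 2
assert residual.resample[1].out_channels == 256  # unchanged width
```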
+
+        Args:
+            x (torch.Tensor): Input tensor
+            feat_cache (list, optional): Feature cache for causal convolutions
+            feat_idx (list, optional): Feature index for cache management
+            first_chunk (bool, optional): Whether this is the first temporal chunk of a chunked decode
+
+        Returns:
+            torch.Tensor: Output tensor
+        """
+        x_copy = x.clone()
+
+        for resnet in self.resnets:
+            if feat_cache is not None:
+                x = resnet(x, feat_cache, feat_idx)
+            else:
+                x = resnet(x)
+
+        if self.upsampler is not None:
+            if feat_cache is not None:
+                x = self.upsampler(x, feat_cache, feat_idx)
+            else:
+                x = self.upsampler(x)
+
+        if self.avg_shortcut is not None:
+            x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
+
+        return x
+
 
 class WanUpBlock(nn.Module):
     """
@@ -513,7 +757,7 @@ class WanUpBlock(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def forward(self, x, feat_cache=None, feat_idx=[0]):
+    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
         """
         Forward pass through the upsampling block.
 
@@ -564,6 +808,8 @@ class WanDecoder3d(nn.Module):
         temperal_upsample=[False, True, True],
         dropout=0.0,
         non_linearity: str = "silu",
+        out_channels: int = 3,
+        is_residual: bool = False,
     ):
         super().__init__()
         self.dim = dim
@@ -577,7 +823,6 @@ class WanDecoder3d(nn.Module):
 
         # dimensions
         dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
-        scale = 1.0 / 2 ** (len(dim_mult) - 2)
 
         # init block
         self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
@@ -589,36 +834,47 @@ class WanDecoder3d(nn.Module):
         self.up_blocks = nn.ModuleList([])
         for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
             # residual (+attention) blocks
-            if i > 0:
+            if i > 0 and not is_residual:
+                # the wan 2.1 upsampler halves the channel width, so halve in_dim here
                 in_dim = in_dim // 2
 
-            # Determine if we need upsampling
+            # determine if we need upsampling
+            up_flag = i != len(dim_mult) - 1
+            # determine the upsampling mode; None means no upsampling
             upsample_mode = None
-            if i != len(dim_mult) - 1:
-                upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
+            if up_flag and temperal_upsample[i]:
+                upsample_mode = "upsample3d"
+            elif up_flag:
+                upsample_mode = "upsample2d"
 
             # Create and add the upsampling block
-            up_block = WanUpBlock(
-                in_dim=in_dim,
-                out_dim=out_dim,
-                num_res_blocks=num_res_blocks,
-                dropout=dropout,
-                upsample_mode=upsample_mode,
-                non_linearity=non_linearity,
-            )
+            if is_residual:
+                up_block = WanResidualUpBlock(
+                    in_dim=in_dim,
+                    out_dim=out_dim,
+                    num_res_blocks=num_res_blocks,
+                    dropout=dropout,
+                    temperal_upsample=temperal_upsample[i] if up_flag else False,
+                    up_flag=up_flag,
+                    non_linearity=non_linearity,
+                )
+            else:
+                up_block = WanUpBlock(
+                    in_dim=in_dim,
+                    out_dim=out_dim,
+                    num_res_blocks=num_res_blocks,
+                    dropout=dropout,
+                    upsample_mode=upsample_mode,
+                    non_linearity=non_linearity,
+                )
             self.up_blocks.append(up_block)
 
-            # Update scale for next iteration
-            if upsample_mode is not None:
-                scale *= 2.0
-
         # output blocks
         self.norm_out = WanRMS_norm(out_dim, images=False)
-        self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
+        self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)
 
         self.gradient_checkpointing = False
 
-    def forward(self, x, feat_cache=None, feat_idx=[0]):
+    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
         ## conv1
         if feat_cache is not None:
             idx = feat_idx[0]
@@ -637,7 +893,7 @@ class WanDecoder3d(nn.Module):
 
         ## upsamples
         for up_block in self.up_blocks:
-            x = up_block(x, feat_cache, feat_idx)
+            x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
 
         ## head
         x = self.norm_out(x)
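The `first_chunk` flag exists because of the causal frame bookkeeping of chunked decoding; a sketch of the arithmetic:

```python
# The first latent frame decodes to one output frame (DupUp3D trims its
# duplicated lead), and every later latent frame expands by the temporal ratio.
latent_frames = 21
temporal_ratio = 4  # 2 ** number of True entries in temperal_downsample
num_frames = 1 + (latent_frames - 1) * temporal_ratio
assert num_frames == 81  # the default 81-frame Wan clip uses 21 latent frames
```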
@@ -656,6 +912,44 @@ class WanDecoder3d(nn.Module):
         return x
 
 
+# YiYi TODO: refactor this
+from einops import rearrange
+
+
+def patchify(x, patch_size):
+    if patch_size == 1:
+        return x
+
+    if x.dim() == 4:
+        # (B, C, H, W) -> fold each spatial patch into the channel dim
+        x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
+    elif x.dim() == 5:
+        # (B, C, F, H, W) -> fold each spatial patch into the channel dim
+        x = rearrange(x, "b c f (h q) (w r) -> b (c r q) f h w", q=patch_size, r=patch_size)
+    else:
+        raise ValueError(f"Invalid input shape: {x.shape}")
+
+    return x
+
+
+def unpatchify(x, patch_size):
+    if patch_size == 1:
+        return x
+
+    if x.dim() == 4:
+        x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
+    elif x.dim() == 5:
+        x = rearrange(x, "b (c r q) f h w -> b c f (h q) (w r)", q=patch_size, r=patch_size)
+    else:
+        raise ValueError(f"Invalid input shape: {x.shape}")
+
+    return x
+
+
 class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
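A hedged roundtrip check for the helpers above: `patchify` is a pixel-unshuffle into channels, and `unpatchify` inverts it exactly.

```python
import torch
from diffusers.models.autoencoders.autoencoder_kl_wan import patchify, unpatchify

x = torch.randn(1, 3, 9, 64, 64)  # (B, C, F, H, W)
p = patchify(x, patch_size=2)
assert p.shape == (1, 12, 9, 32, 32)  # C * patch_size**2 channels
assert torch.equal(unpatchify(p, patch_size=2), x)  # exact inverse
```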
@@ -671,6 +965,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     def __init__(
         self,
         base_dim: int = 96,
+        decoder_base_dim: Optional[int] = None,
         z_dim: int = 16,
         dim_mult: Tuple[int] = [1, 2, 4, 4],
         num_res_blocks: int = 2,
@@ -713,6 +1008,10 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             2.8251,
             1.9160,
         ],
+        is_residual: bool = False,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        patch_size: Optional[int] = None,
     ) -> None:
         super().__init__()
 
@@ -720,14 +1019,17 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.temperal_downsample = temperal_downsample
         self.temperal_upsample = temperal_downsample[::-1]
 
+        if decoder_base_dim is None:
+            decoder_base_dim = base_dim
+
         self.encoder = WanEncoder3d(
-            base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
+            in_channels=in_channels,
+            dim=base_dim,
+            z_dim=z_dim * 2,
+            dim_mult=dim_mult,
+            num_res_blocks=num_res_blocks,
+            attn_scales=attn_scales,
+            temperal_downsample=temperal_downsample,
+            dropout=dropout,
+            is_residual=is_residual,
         )
         self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
         self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
 
         self.decoder = WanDecoder3d(
-            base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
+            dim=decoder_base_dim,
+            z_dim=z_dim,
+            dim_mult=dim_mult,
+            num_res_blocks=num_res_blocks,
+            attn_scales=attn_scales,
+            temperal_upsample=self.temperal_upsample,
+            dropout=dropout,
+            out_channels=out_channels,
+            is_residual=is_residual,
         )
 
         self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
@@ -827,6 +1129,8 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             return self.tiled_encode(x)
 
         self.clear_cache()
+        if self.config.patch_size is not None:
+            x = patchify(x, patch_size=self.config.patch_size)
         iter_ = 1 + (num_frame - 1) // 4
         for i in range(iter_):
             self._enc_conv_idx = [0]
@@ -884,12 +1188,14 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for i in range(num_frame):
             self._conv_idx = [0]
             if i == 0:
-                out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                out = self.decoder(
+                    x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True
+                )
             else:
                 out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
                 out = torch.cat([out, out_], 2)
 
         out = torch.clamp(out, min=-1.0, max=1.0)
+        if self.config.patch_size is not None:
+            out = unpatchify(out, patch_size=self.config.patch_size)
         self.clear_cache()
         if not return_dict:
             return (out,)
diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py
index d14dac91f1..748b20c112 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan.py
@@ -112,10 +112,21 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        transformer_2 ([`WanTransformer3DModel`], *optional*):
+            Conditional transformer that denoises the input latents during the low-noise stage. If provided, two-stage
+            denoising is enabled: `transformer` handles the high-noise stage and `transformer_2` handles the low-noise
+            stage. If not provided, only `transformer` is used.
+        boundary_ratio (`float`, *optional*, defaults to `None`):
+            Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+            The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+            `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
     """
 
-    model_cpu_offload_seq = "text_encoder->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+    _optional_components = ["transformer_2"]
 
     def __init__(
         self,
@@ -124,6 +135,8 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         transformer: WanTransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        transformer_2: Optional[WanTransformer3DModel] = None,
+        boundary_ratio: Optional[float] = None,
     ):
         super().__init__()
 
@@ -133,8 +146,9 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             tokenizer=tokenizer,
             transformer=transformer,
             scheduler=scheduler,
+            transformer_2=transformer_2,
         )
-
+        self.register_to_config(boundary_ratio=boundary_ratio)
         self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
         self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
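A hedged end-to-end usage sketch of the two-stage pipeline (the repo id and guidance values are illustrative assumptions, not taken from this diff):

```python
import torch
from diffusers import WanPipeline
from diffusers.utils import export_to_video

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.2-T2V-A14B-Diffusers",  # assumed checkpoint shipping transformer_2 + boundary_ratio
    torch_dtype=torch.bfloat16,
).to("cuda")

# high-noise steps run on `transformer` with guidance_scale; low-noise steps
# run on `transformer_2` with guidance_scale_2
video = pipe(
    prompt="A cat surfing a wave at sunset",
    num_frames=81,
    guidance_scale=4.0,
    guidance_scale_2=3.0,
).frames[0]
export_to_video(video, "wan22.mp4", fps=16)
```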
@@ -270,6 +284,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         prompt_embeds=None,
         negative_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
+        guidance_scale_2=None,
     ):
         if height % 16 != 0 or width % 16 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -301,6 +316,9 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
         ):
             raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+        if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+            raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
 
     def prepare_latents(
         self,
@@ -369,6 +387,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         num_frames: int = 81,
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
+        guidance_scale_2: Optional[float] = None,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -407,6 +426,9 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                 `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                 the text `prompt`, usually at the expense of lower image quality.
+            guidance_scale_2 (`float`, *optional*, defaults to `None`):
+                Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+                `boundary_ratio` is not `None`, it defaults to the same value as `guidance_scale`. Only used when
+                `transformer_2` and the pipeline's `boundary_ratio` are not `None`.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -461,6 +483,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             prompt_embeds,
             negative_prompt_embeds,
             callback_on_step_end_tensor_inputs,
+            guidance_scale_2,
         )
 
         if num_frames % self.vae_scale_factor_temporal != 1:
@@ -470,7 +493,11 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
         num_frames = max(num_frames, 1)
 
+        if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+            guidance_scale_2 = guidance_scale
+
         self._guidance_scale = guidance_scale
+        self._guidance_scale_2 = guidance_scale_2
         self._attention_kwargs = attention_kwargs
         self._current_timestep = None
         self._interrupt = False
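A sketch of the model switch implied by `boundary_ratio` (the 0.875 value is illustrative): flow-matching timesteps run from high noise to low noise, so the earliest steps fall to `transformer` and the remainder to `transformer_2`.

```python
boundary_timestep = 0.875 * 1000  # boundary_ratio * num_train_timesteps
for t in [999, 900, 875, 874, 500, 10]:
    stage = "transformer (high noise)" if t >= boundary_timestep else "transformer_2 (low noise)"
    print(t, "->", stage)
```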
@@ -524,34 +551,49 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
 
+        if self.config.boundary_ratio is not None:
+            boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+        else:
+            boundary_timestep = None
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
                 self._current_timestep = t
+
+                if boundary_timestep is None or t >= boundary_timestep:
+                    # wan2.1 or high-noise stage in wan2.2
+                    current_model = self.transformer
+                    current_guidance_scale = guidance_scale
+                else:
+                    # low-noise stage in wan2.2
+                    current_model = self.transformer_2
+                    current_guidance_scale = guidance_scale_2
+
                 latent_model_input = latents.to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])
 
-                with self.transformer.cache_context("cond"):
-                    noise_pred = self.transformer(
+                with current_model.cache_context("cond"):
+                    noise_pred = current_model(
                         hidden_states=latent_model_input,
                         timestep=timestep,
                         encoder_hidden_states=prompt_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
 
                 if self.do_classifier_free_guidance:
-                    with self.transformer.cache_context("uncond"):
-                        noise_uncond = self.transformer(
+                    with current_model.cache_context("uncond"):
+                        noise_uncond = current_model(
                             hidden_states=latent_model_input,
                             timestep=timestep,
                             encoder_hidden_states=negative_prompt_embeds,
                             attention_kwargs=attention_kwargs,
                             return_dict=False,
                         )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
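And a trivial numeric check of the per-stage CFG update (`guided = uncond + scale * (cond - uncond)`), with the scale now taken from whichever stage is active:

```python
import torch

cond = torch.tensor([1.0])
uncond = torch.tensor([0.2])
guided = uncond + 4.0 * (cond - uncond)  # current_guidance_scale = 4.0
assert torch.allclose(guided, torch.tensor([3.4]))
```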