mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add a timestep scale for sana-sprint teacher model (#11150)
This commit is contained in:
@@ -326,6 +326,10 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
Whether to use elementwise affinity in the normalization layer.
|
||||
norm_eps (`float`, defaults to `1e-6`):
|
||||
The epsilon value for the normalization layer.
|
||||
qk_norm (`str`, *optional*, defaults to `None`):
|
||||
The normalization to use for the query and key.
|
||||
timestep_scale (`float`, defaults to `1.0`):
|
||||
The scale to use for the timesteps.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -355,6 +359,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
guidance_embeds: bool = False,
|
||||
guidance_embeds_scale: float = 0.1,
|
||||
qk_norm: Optional[str] = None,
|
||||
timestep_scale: float = 1.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -938,6 +938,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
|
||||
timestep = timestep * self.transformer.config.timestep_scale
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
|
||||
Reference in New Issue
Block a user