diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md
index 3464d88145..0ad558fef9 100644
--- a/docs/source/en/api/pipelines/ltx_video.md
+++ b/docs/source/en/api/pipelines/ltx_video.md
@@ -35,6 +35,7 @@ Available models:
| [`LTX Video 2B 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
| [`LTX Video 2B 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` |
| [`LTX Video 13B 0.9.7`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev.safetensors) | `torch.bfloat16` |
+| [`LTX Video 13B 0.9.7 (distilled)`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled.safetensors) | `torch.bfloat16` |
| [`LTX Video Spatial Upscaler 0.9.7`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-spatial-upscaler-0.9.7.safetensors) | `torch.bfloat16` |
Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.
@@ -47,6 +48,14 @@ For the best results, it is recommended to follow the guidelines mentioned in th
- For variants with a timestep-aware VAE (LTXV 0.9.1 and above), it is recommended to set `decode_timestep` to `0.05` and `image_cond_noise_scale` to `0.025`.
- For variants that support interpolation between multiple conditioning images and videos (LTXV 0.9.5 and above), it is recommended to use similar looking images/videos for the best results. High divergence between the conditionings may lead to abrupt transitions in the generated video.
+
+
+
+
+The examples below show some recommended generation settings, but note that all features supported in the original [LTX Video repository](https://github.com/Lightricks/LTX-Video) are not supported in `diffusers` yet (for example, Spatio-temporal Guidance and CRF compression for image inputs). This will gradually be supported in the future. For the best possible generation quality, we recommend using the code from the original repository.
+
+
+
## Using LTX Video 13B 0.9.7
LTX Video 0.9.7 comes with a spatial latent upscaler and a 13B parameter transformer. The inference involves generating a low resolution video first, which is very fast, followed by upscaling and refining the generated video.
@@ -59,8 +68,8 @@ from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import export_to_video, load_video
-pipe = LTXConditionPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-diffusers", torch_dtype=torch.bfloat16)
-pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-Latent-Spatial-Upsampler-diffusers", vae=pipe.vae, torch_dtype=torch.bfloat16)
+pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.7-dev", torch_dtype=torch.bfloat16)
+pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained("Lightricks/ltxv-spatial-upscaler-0.9.7", vae=pipe.vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
pipe_upsample.to("cuda")
pipe.vae.enable_tiling()
@@ -93,6 +102,11 @@ latents = pipe(
height=downscaled_height,
num_frames=num_frames,
num_inference_steps=30,
+ decode_timestep=0.05,
+ decode_noise_scale=0.025,
+ image_cond_noise_scale=0.0,
+ guidance_scale=5.0,
+ guidance_rescale=0.7,
generator=torch.Generator().manual_seed(0),
output_type="latent",
).frames
@@ -117,7 +131,10 @@ video = pipe(
num_inference_steps=10,
latents=upscaled_latents,
decode_timestep=0.05,
- image_cond_noise_scale=0.025,
+ decode_noise_scale=0.025,
+ image_cond_noise_scale=0.0,
+ guidance_scale=5.0,
+ guidance_rescale=0.7,
generator=torch.Generator().manual_seed(0),
output_type="pil",
).frames[0]
@@ -128,6 +145,95 @@ video = [frame.resize((expected_width, expected_height)) for frame in video]
export_to_video(video, "output.mp4", fps=24)
```
+## Using LTX Video 0.9.7 (distilled)
+
+The same example as above can be used with the exception of the `guidance_scale` parameter. The model is both guidance and timestep distilled in order to speedup generation. It requires `guidance_scale` to be set to `1.0`. Additionally, to benefit from the timestep distillation, `num_inference_steps` can be set between `4` and `10` for good generation quality.
+
+Additionally, custom timesteps can also be used for conditioning the generation. The authors recommend using the following timesteps for best results:
+- Base model inference to prepare for upscaling: `[1000, 993, 987, 981, 975, 909, 725, 0.03]`
+- Upscaling: `[1000, 909, 725, 421, 0]`
+
+
+ Full example
+
+```python
+import torch
+from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
+from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
+from diffusers.utils import export_to_video, load_video
+
+pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.7-distilled", torch_dtype=torch.bfloat16)
+pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained("Lightricks/ltxv-spatial-upscaler-0.9.7", vae=pipe.vae, torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+pipe_upsample.to("cuda")
+pipe.vae.enable_tiling()
+
+def round_to_nearest_resolution_acceptable_by_vae(height, width):
+ height = height - (height % pipe.vae_temporal_compression_ratio)
+ width = width - (width % pipe.vae_temporal_compression_ratio)
+ return height, width
+
+prompt = "artistic anatomical 3d render, utlra quality, human half full male body with transparent skin revealing structure instead of organs, muscular, intricate creative patterns, monochromatic with backlighting, lightning mesh, scientific concept art, blending biology with botany, surreal and ethereal quality, unreal engine 5, ray tracing, ultra realistic, 16K UHD, rich details. camera zooms out in a rotating fashion"
+negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+expected_height, expected_width = 768, 1152
+downscale_factor = 2 / 3
+num_frames = 161
+
+# Part 1. Generate video at smaller resolution
+downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(expected_width * downscale_factor)
+downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(downscaled_height, downscaled_width)
+latents = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=downscaled_width,
+ height=downscaled_height,
+ num_frames=num_frames,
+ timesteps=[1000, 993, 987, 981, 975, 909, 725, 0.03],
+ decode_timestep=0.05,
+ decode_noise_scale=0.025,
+ image_cond_noise_scale=0.0,
+ guidance_scale=1.0,
+ guidance_rescale=0.7,
+ generator=torch.Generator().manual_seed(0),
+ output_type="latent",
+).frames
+
+# Part 2. Upscale generated video using latent upsampler with fewer inference steps
+# The available latent upsampler upscales the height/width by 2x
+upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2
+upscaled_latents = pipe_upsample(
+ latents=latents,
+ adain_factor=1.0,
+ output_type="latent"
+).frames
+
+# Part 3. Denoise the upscaled video with few steps to improve texture (optional, but recommended)
+video = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=upscaled_width,
+ height=upscaled_height,
+ num_frames=num_frames,
+ denoise_strength=0.999, # Effectively, 4 inference steps out of 5
+ timesteps=[1000, 909, 725, 421, 0],
+ latents=upscaled_latents,
+ decode_timestep=0.05,
+ decode_noise_scale=0.025,
+ image_cond_noise_scale=0.0,
+ guidance_scale=1.0,
+ guidance_rescale=0.7,
+ generator=torch.Generator().manual_seed(0),
+ output_type="pil",
+).frames[0]
+
+# Part 4. Downscale the video to the expected resolution
+video = [frame.resize((expected_width, expected_height)) for frame in video]
+
+export_to_video(video, "output.mp4", fps=24)
+```
+
+
+
## Loading Single Files
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. We recommend using `from_single_file` for the Lightricks series of models, as they plan to release multiple models in the future in the single file format.
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py
index 606d146fd1..7f669ef50e 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py
@@ -140,6 +140,33 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
r"""
Pipeline for text-to-video generation.
@@ -481,6 +508,10 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
def guidance_scale(self):
return self._guidance_scale
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@@ -514,6 +545,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 3,
+ guidance_rescale: float = 0.0,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -556,6 +588,11 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -624,6 +661,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
)
self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
self._attention_kwargs = attention_kwargs
self._interrupt = False
self._current_timestep = None
@@ -737,6 +775,12 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ if self.guidance_rescale > 0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(
+ noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
+ )
+
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
index a99294a9e2..4724880658 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
@@ -222,6 +222,33 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
r"""
Pipeline for text/image/video-to-video generation.
@@ -794,6 +821,10 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
def guidance_scale(self):
return self._guidance_scale
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@@ -833,6 +864,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 3,
+ guidance_rescale: float = 0.0,
image_cond_noise_scale: float = 0.15,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -893,6 +925,11 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -967,6 +1004,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
)
self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
self._attention_kwargs = attention_kwargs
self._interrupt = False
self._current_timestep = None
@@ -1063,9 +1101,11 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
- sigmas = linear_quadratic_schedule(num_inference_steps)
- timesteps = sigmas * 1000
+ if timesteps is None:
+ sigmas = linear_quadratic_schedule(num_inference_steps)
+ timesteps = sigmas * 1000
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ sigmas = self.scheduler.sigmas
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
latent_sigma = None
@@ -1152,6 +1192,12 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
timestep, _ = timestep.chunk(2)
+ if self.guidance_rescale > 0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(
+ noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
+ )
+
denoised_latents = self.scheduler.step(
-noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False
)[0]
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
index 4162949559..94bb63b4a4 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
@@ -159,6 +159,33 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
r"""
Pipeline for image-to-video generation.
@@ -542,6 +569,10 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
def guidance_scale(self):
return self._guidance_scale
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@@ -576,6 +607,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 3,
+ guidance_rescale: float = 0.0,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -620,6 +652,11 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -688,6 +725,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
)
self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
self._attention_kwargs = attention_kwargs
self._interrupt = False
self._current_timestep = None
@@ -811,6 +849,12 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
timestep, _ = timestep.chunk(2)
+ if self.guidance_rescale > 0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(
+ noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
+ )
+
# compute the previous noisy sample x_t -> x_t-1
noise_pred = self._unpack_latents(
noise_pred,
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py b/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py
index a90e52e717..49cf94e25d 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py
@@ -91,6 +91,34 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
return init_latents
+ def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0):
+ """
+ Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on statistics from a reference latent
+ tensor.
+
+ Args:
+ latent (`torch.Tensor`):
+ Input latents to normalize
+ reference_latents (`torch.Tensor`):
+ The reference latents providing style statistics.
+ factor (`float`):
+ Blending factor between original and transformed latent. Range: -10.0 to 10.0, Default: 1.0
+
+ Returns:
+ torch.Tensor: The transformed latent tensor
+ """
+ result = latents.clone()
+
+ for i in range(latents.size(0)):
+ for c in range(latents.size(1)):
+ r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order
+ i_sd, i_mean = torch.std_mean(result[i, c], dim=None)
+
+ result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean
+
+ result = torch.lerp(latents, result, factor)
+ return result
+
@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
def _normalize_latents(
@@ -160,6 +188,7 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
latents: Optional[torch.Tensor] = None,
decode_timestep: Union[float, List[float]] = 0.0,
decode_noise_scale: Optional[Union[float, List[float]]] = None,
+ adain_factor: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
@@ -204,7 +233,12 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
latents = latents.to(self.latent_upsampler.dtype)
- latents = self.latent_upsampler(latents)
+ latents_upsampled = self.latent_upsampler(latents)
+
+ if adain_factor > 0.0:
+ latents = self.adain_filter_latent(latents_upsampled, latents, adain_factor)
+ else:
+ latents = latents_upsampled
if output_type == "latent":
latents = self._normalize_latents(