1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

support flux, ltx i2v, ltx condition

This commit is contained in:
Aryan
2025-04-02 01:21:09 +02:00
parent 41b0c473d2
commit 1f33ca276d
3 changed files with 8 additions and 3 deletions

View File

@@ -906,7 +906,7 @@ class FluxPipeline(
)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
@@ -917,6 +917,7 @@ class FluxPipeline(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
cc.mark_state("cond")
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
@@ -932,6 +933,8 @@ class FluxPipeline(
if do_true_cfg:
if negative_image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
cc.mark_state("uncond")
neg_noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,

View File

@@ -1061,7 +1061,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
self._num_timesteps = len(timesteps)
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
@@ -1090,6 +1090,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
cc.mark_state("cond_uncond")
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,

View File

@@ -771,7 +771,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
)
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
@@ -783,6 +783,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
timestep = t.expand(latent_model_input.shape[0])
timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
cc.mark_state("cond_uncond")
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,