mirror of https://github.com/huggingface/diffusers.git
updates
@@ -35,6 +35,23 @@ from ..normalization import FP32LayerNorm

logger = logging.get_logger(__name__)


if torch.cuda.get_device_capability()[0] >= 9:
    try:
        from flash_attn_interface import flash_attn_func as FA
    except:
        FA = None

    try:
        from flash_attn import flash_attn_func as FA
    except:
        FA = None
else:
    try:
        from flash_attn import flash_attn_func as FA
    except:
        FA = None


# @torch.compile()
@torch.autocast(device_type="cuda", dtype=torch.float32)
def apply_scale_shift_norm(norm, x, scale, shift):

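Note (not part of the diff): in the hunk above, the second try in the Hopper branch re-imports from flash_attn unconditionally, so a successful flash_attn_interface import is immediately overwritten, and the bare except also swallows unrelated errors. A minimal sketch of a more defensive version of this guarded import, assuming the same flash_attn / flash_attn_interface package names:

import torch

FA = None  # flash-attention kernel; stays None when no backend is importable

if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9:
    # Prefer the FlashAttention-3 interface on Hopper-class GPUs (SM 9.x).
    try:
        from flash_attn_interface import flash_attn_func as FA
    except ImportError:
        FA = None

if FA is None:
    # Fall back to FlashAttention-2; callers can still use SDPA if this fails too.
    try:
        from flash_attn import flash_attn_func as FA
    except ImportError:
        FA = None
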
@@ -99,7 +116,7 @@ class VisualEmbeddings(nn.Module):
        super().__init__()
        self.patch_size = patch_size
        self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim)


    def forward(self, x):
        batch_size, duration, height, width, dim = x.shape
        x = (

@@ -107,7 +124,7 @@ class VisualEmbeddings(nn.Module):
                batch_size,
                duration // self.patch_size[0],
                self.patch_size[0],
                height // self.patch_size[1],
                height // self.patch_size[1],
                self.patch_size[1],
                width // self.patch_size[2],
                self.patch_size[2],

@@ -169,24 +186,23 @@ class RoPE3D(nn.Module):
    @torch.autocast(device_type="cuda", enabled=False)
    def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)):
        batch_size, duration, height, width = shape

        args_t = self.args_0[pos[0]] / scale_factor[0]
        args_h = self.args_1[pos[1]] / scale_factor[1]
        args_w = self.args_2[pos[2]] / scale_factor[2]

        # Replicate the original logic with batch dimension
        args_t_expanded = args_t.view(1, duration, 1, 1, -1).expand(batch_size, -1, height, width, -1)
        args_h_expanded = args_h.view(1, 1, height, 1, -1).expand(batch_size, duration, -1, width, -1)
        args_w_expanded = args_w.view(1, 1, 1, width, -1).expand(batch_size, duration, height, -1, -1)

        # Concatenate along the last dimension
        args = torch.cat([args_t_expanded, args_h_expanded, args_w_expanded], dim=-1)  # [B, D, H, W, F]
        args = torch.cat([args_t_expanded, args_h_expanded, args_w_expanded], dim=-1)

        cosine = torch.cos(args)
        sine = torch.sin(args)
        rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)  # [B, D, H, W, F, 4]
        rope = rope.view(*rope.shape[:-1], 2, 2)  # [B, D, H, W, F, 2, 2]
        return rope.unsqueeze(-4)  # [B, D, H, 1, W, F, 2, 2]

        rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)
        rope = rope.view(*rope.shape[:-1], 2, 2)
        return rope.unsqueeze(-4)


class Modulation(nn.Module):
    def __init__(self, time_dim, model_dim, num_params):

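Note (not part of the diff): the hunk above builds per-position 2x2 rotation matrices with a trailing [..., F, 2, 2] shape, which apply_rotary then multiplies into the query/key channels. apply_rotary itself is not shown in this diff, so the following is only a sketch of that pairing, under the assumption that query/key are laid out as [batch, seq, heads, head_dim] and that rope is [batch, seq, 1, head_dim // 2, 2, 2] so it broadcasts over the heads axis (which matches the flatten(1, 3) applied to visual_rope later in this file):

import torch


def apply_rotary_sketch(x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
    # x:    [batch, seq, heads, head_dim], head_dim assumed even
    # rope: [batch, seq, 1, head_dim // 2, 2, 2], broadcasting over heads
    x_pairs = x.unflatten(-1, (-1, 2)).unsqueeze(-1)          # [..., head_dim // 2, 2, 1]
    rotated = (rope.to(x_pairs.dtype) @ x_pairs).squeeze(-1)  # [..., head_dim // 2, 2]
    return rotated.flatten(-2, -1)                            # [..., head_dim]

In the attention hunks below the result is cast back to the input dtype with .type_as(...).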
@@ -230,11 +246,14 @@ class MultiheadSelfAttentionEnc(nn.Module):
        key = apply_rotary(key, rope).type_as(key)

        # Use torch's scaled_dot_product_attention
        out = F.scaled_dot_product_attention(
            query,
            key,
            value,
        ).flatten(-2, -1)
        # print(query.shape, key.shape, value.shape, "QKV MultiheadSelfAttentionEnc SHAPE")
        # out = F.scaled_dot_product_attention(
        #     query.permute(0, 2, 1, 3),
        #     key.permute(0, 2, 1, 3),
        #     value.permute(0, 2, 1, 3),
        # ).permute(0, 2, 1, 3).flatten(-2, -1)

        out = FA(q=query, k=key, v=value).flatten(-2, -1)

        out = self.out_layer(out)
        return out

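Note (not part of the diff): this and the next two hunks replace F.scaled_dot_product_attention with the FA kernel imported at the top of the file, which leaves no fallback when neither flash-attention package is installed. A minimal sketch of a dispatch helper that keeps an SDPA path, assuming the projected tensors are laid out as [batch, seq, heads, head_dim] (the layout flash_attn_func expects, and the one suggested by the commented-out permutes):

import torch
import torch.nn.functional as F


def attention(query, key, value, flash_fn=None):
    # query/key/value: [batch, seq, heads, head_dim] (assumed layout)
    if flash_fn is not None:
        # flash_attn_func consumes (batch, seq, heads, head_dim) directly.
        out = flash_fn(q=query, k=key, v=value)
    else:
        # SDPA wants (batch, heads, seq, head_dim), so permute around the call.
        out = F.scaled_dot_product_attention(
            query.permute(0, 2, 1, 3),
            key.permute(0, 2, 1, 3),
            value.permute(0, 2, 1, 3),
        ).permute(0, 2, 1, 3)
    return out.flatten(-2, -1)  # merge heads back into the channel dim
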
@@ -270,11 +289,15 @@ class MultiheadSelfAttentionDec(nn.Module):
        key = apply_rotary(key, rope).type_as(key)

        # Use standard attention (can be extended with sparse attention)
        out = F.scaled_dot_product_attention(
            query,
            key,
            value,
        ).flatten(-2, -1)
        # out = F.scaled_dot_product_attention(
        #     query.permute(0, 2, 1, 3),
        #     key.permute(0, 2, 1, 3),
        #     value.permute(0, 2, 1, 3),
        # ).permute(0, 2, 1, 3).flatten(-2, -1)

        # print(query.shape, key.shape, value.shape, "QKV MultiheadSelfAttentionDec SHAPE")

        out = FA(q=query, k=key, v=value).flatten(-2, -1)

        out = self.out_layer(out)
        return out

@@ -306,11 +329,15 @@ class MultiheadCrossAttention(nn.Module):
        query = self.query_norm(query.float()).type_as(query)
        key = self.key_norm(key.float()).type_as(key)

        out = F.scaled_dot_product_attention(
            query.permute(0, 2, 1, 3),
            key.permute(0, 2, 1, 3),
            value.permute(0, 2, 1, 3),
        ).permute(0, 2, 1, 3).flatten(-2, -1)
        # out = F.scaled_dot_product_attention(
        #     query.permute(0, 2, 1, 3),
        #     key.permute(0, 2, 1, 3),
        #     value.permute(0, 2, 1, 3),
        # ).permute(0, 2, 1, 3).flatten(-2, -1)

        # print(query.shape, key.shape, value.shape, "QKV MultiheadCrossAttention SHAPE")

        out = FA(q=query, k=key, v=value).flatten(-2, -1)

        out = self.out_layer(out)
        return out

@@ -339,19 +366,18 @@ class TransformerEncoderBlock(nn.Module):
        self.feed_forward = FeedForward(model_dim, ff_dim)

    def forward(self, x, time_embed, rope):
        self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
        self_attn_params, ff_params = torch.chunk(
            self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1
        )
        shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)

        out = self.self_attention_norm(x)
        out = out * (scale + 1.0) + shift
        out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
        out = self.self_attention(out, rope)
        x = x + gate * out
        x = apply_gate_sum(x, out, gate)

        shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
        out = self.feed_forward_norm(x)
        out = out * (scale + 1.0) + shift
        out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
        out = self.feed_forward(out)
        x = x + gate * out
        x = apply_gate_sum(x, out, gate)
        return x

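Note (not part of the diff): apply_scale_shift_norm is declared in the first hunk of this file and apply_gate_sum is not shown at all, so the bodies below are only inferred from the two-line forms they replace (norm plus AdaLN-style scale/shift, and a gated residual sum); treat this as a sketch, not the committed implementation:

import torch


@torch.autocast(device_type="cuda", dtype=torch.float32)
def apply_scale_shift_norm(norm, x, scale, shift):
    # Normalize, then modulate with time-conditioned scale and shift
    # (equivalent to the removed `norm(x) * (scale + 1.0) + shift`).
    return norm(x) * (scale + 1.0) + shift


def apply_gate_sum(x, out, gate):
    # Gated residual connection (equivalent to the removed `x + gate * out`).
    return x + gate * out
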
@@ -371,26 +397,22 @@ class TransformerDecoderBlock(nn.Module):

    def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params):
        self_attn_params, cross_attn_params, ff_params = torch.chunk(
            self.visual_modulation(time_embed), 3, dim=-1
            self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1
        )
        shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)

        visual_out = self.self_attention_norm(visual_embed)
        visual_out = visual_out * (scale + 1.0) + shift
        visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
        visual_out = self.self_attention(visual_out, rope, sparse_params)
        visual_embed = visual_embed + gate * visual_out
        visual_embed = apply_gate_sum(visual_embed, visual_out, gate)

        shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1)
        visual_out = self.cross_attention_norm(visual_embed)
        visual_out = visual_out * (scale + 1.0) + shift
        visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
        visual_out = self.cross_attention(visual_out, text_embed)
        visual_embed = visual_embed + gate * visual_out
        visual_embed = apply_gate_sum(visual_embed, visual_out, gate)

        shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
        visual_out = self.feed_forward_norm(visual_embed)
        visual_out = visual_out * (scale + 1.0) + shift
        visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
        visual_out = self.feed_forward(visual_out)
        visual_embed = visual_embed + gate * visual_out
        visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
        return visual_embed

@@ -575,7 +597,7 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
        # 1. Process text embeddings
        text_embed = self.text_embeddings(encoder_hidden_states)
        time_embed = self.time_embeddings(timestep)

        # Add pooled text embedding to time embedding
        pooled_embed = self.pooled_text_embeddings(pooled_text_embed)
        time_embed = time_embed + pooled_embed

@@ -587,22 +609,29 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
        text_rope = self.text_rope_embeddings(text_rope_pos)

        # 4. Text transformer blocks
        i = 0
        for text_block in self.text_transformer_blocks:
            if self.gradient_checkpointing and self.training:
                text_embed = torch.utils.checkpoint.checkpoint(
                    text_block, text_embed, time_embed, text_rope, use_reentrant=False
                )

            else:
                text_embed = text_block(text_embed, time_embed, text_rope)

            i += 1

        # 5. Prepare visual rope
        visual_shape = visual_embed.shape[:-1]
        visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor)

        # visual_embed = visual_embed.reshape(visual_embed.shape[0], -1, visual_embed.shape[-1])
        # visual_rope = visual_rope.view(visual_rope.shape[0], -1, *list(visual_rope.shape[-4:]))
        visual_embed = visual_embed.flatten(1, 3)
        visual_rope = visual_rope.flatten(1, 3)

        visual_embed = visual_embed.reshape(visual_embed.shape[0], -1, visual_embed.shape[-1])
        visual_rope = visual_rope.view(visual_rope.shape[0], -1, *list(visual_rope.shape[-4:]))

        # 6. Visual transformer blocks
        i = 0
        for visual_block in self.visual_transformer_blocks:
            if self.gradient_checkpointing and self.training:
                visual_embed = torch.utils.checkpoint.checkpoint(

@@ -619,6 +648,8 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
                visual_embed = visual_block(
                    visual_embed, text_embed, time_embed, visual_rope, sparse_params
                )

            i += 1

        # 7. Output projection
        visual_embed = visual_embed.reshape(batch_size, num_frames, height // 2, width // 2, -1)

@@ -220,19 +220,14 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
    ):
        device = device or self._execution_device

        # Encode with Qwen2.5-VL
        prompt_embeds, prompt_cu_seqlens = self._encode_prompt_qwen(
            prompt, device, num_videos_per_prompt
        )
        prompt_embeds, prompt_cu_seqlens = self._encode_prompt_qwen(prompt, device, num_videos_per_prompt)
        pooled_embed = self._encode_prompt_clip(prompt, device, num_videos_per_prompt)

        if do_classifier_free_guidance:
            negative_prompt = negative_prompt or ""
            negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            negative_prompt_embeds, negative_cu_seqlens = self._encode_prompt_qwen(
                negative_prompt, device, num_videos_per_prompt
            )
            negative_prompt_embeds, negative_cu_seqlens = self._encode_prompt_qwen(negative_prompt, device, num_videos_per_prompt)
            negative_pooled_embed = self._encode_prompt_clip(negative_prompt, device, num_videos_per_prompt)
        else:
            negative_prompt_embeds = None

@@ -264,23 +259,25 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if latents is not None:
            return latents.to(device=device, dtype=dtype)
            num_latent_frames = latents.shape[1]
            latents = latents.to(device=device, dtype=dtype)

        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        shape = (
            batch_size,
            num_latent_frames,
            int(height) // self.vae_scale_factor_spatial,
            int(width) // self.vae_scale_factor_spatial,
            num_channels_latents,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        else:
            num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
            shape = (
                batch_size,
                num_latent_frames,
                int(height) // self.vae_scale_factor_spatial,
                int(width) // self.vae_scale_factor_spatial,
                num_channels_latents,
            )
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        if visual_cond:
            # For visual conditioning, concatenate with zeros and mask

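Note (not part of the diff): as a rough worked example of the latent shape built above, assuming hypothetical call values and the factors used elsewhere in this diff (temporal VAE factor 4, spatial VAE factor 8, 16 latent channels):

# Hypothetical call values, not taken from the diff:
batch_size, num_frames, height, width = 1, 121, 768, 512
vae_scale_factor_temporal, vae_scale_factor_spatial, num_channels_latents = 4, 8, 16

num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1  # 31
shape = (
    batch_size,                          # 1
    num_latent_frames,                   # 31
    height // vae_scale_factor_spatial,  # 96
    width // vae_scale_factor_spatial,   # 64
    num_channels_latents,                # 16
)
print(shape)  # (1, 31, 96, 64, 16)
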
@@ -294,50 +291,6 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):

        return latents

    def get_velocity(
        self,
        latents: torch.Tensor,
        timestep: torch.Tensor,
        text_embeds: Dict[str, torch.Tensor],
        negative_text_embeds: Optional[Dict[str, torch.Tensor]],
        visual_rope_pos: List[torch.Tensor],
        text_rope_pos: torch.Tensor,
        negative_text_rope_pos: torch.Tensor,
        guidance_scale: float,
        sparse_params: Optional[Dict] = None,
    ):
        # print(latents.shape, text_embeds["text_embeds"].shape, text_embeds["pooled_embed"].shape, timestep.shape, [el.shape for el in visual_rope_pos], text_rope_pos, sparse_params)

        pred_velocity = self.transformer(
            latents,
            text_embeds["text_embeds"],
            text_embeds["pooled_embed"],
            timestep * 1000,  # Scale to match training
            visual_rope_pos,
            text_rope_pos,
            scale_factor=(1, 2, 2),  # From Kandinsky config
            sparse_params=sparse_params,
            return_dict=False
        )[0]

        if guidance_scale > 1.0 and negative_text_embeds is not None:
            uncond_pred_velocity = self.transformer(
                latents,
                negative_text_embeds["text_embeds"],
                negative_text_embeds["pooled_embed"],
                timestep * 1000,
                visual_rope_pos,
                negative_text_rope_pos,
                scale_factor=(1, 2, 2),
                sparse_params=sparse_params,
                return_dict=False
            )[0]

            pred_velocity = uncond_pred_velocity + guidance_scale * (
                pred_velocity - uncond_pred_velocity
            )

        return pred_velocity

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)

@@ -402,11 +355,9 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
        """

        # 1. Check inputs
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

        # 2. Define call parameters
        if isinstance(prompt, str):
            batch_size = 1
        else:

@@ -415,16 +366,18 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
        device = self._execution_device
        do_classifier_free_guidance = guidance_scale > 1.0


        if num_frames % self.vae_scale_factor_temporal != 1:
            logger.warning(
                f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
            )
            num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
            num_frames = max(num_frames, 1)

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order


        # 3. Encode input prompt
        text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,

@@ -433,11 +386,6 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
            device=device,
        )

        # 4. Prepare timesteps (Kandinsky uses custom flow matching)
        timesteps = torch.linspace(1, 0, num_inference_steps + 1, device=device)
        timesteps = scheduler_scale * timesteps / (1 + (scheduler_scale - 1) * timesteps)

        # 5. Prepare latent variables
        num_channels_latents = 16
        latents = self.prepare_latents(
            batch_size=batch_size * num_videos_per_prompt,

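Note (not part of the diff): the removed lines warp a linear [1, 0] grid with sigma = s * t / (1 + (s - 1) * t). A short worked example with a hypothetical scheduler_scale of 5 shows that for s > 1 the grid points cluster near the fully noised end (sigma = 1):

import torch

scheduler_scale = 5.0  # hypothetical shift value, not taken from the diff
num_inference_steps = 4

t = torch.linspace(1, 0, num_inference_steps + 1)  # [1.00, 0.75, 0.50, 0.25, 0.00]
sigma = scheduler_scale * t / (1 + (scheduler_scale - 1) * t)
print(sigma)  # tensor([1.0000, 0.9375, 0.8333, 0.6250, 0.0000])
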
@@ -451,11 +399,12 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
            generator=generator,
            latents=latents,
        )

        visual_cond = latents[:, :, :, :, 16:]

        # 6. Prepare rope positions
        visual_rope_pos = [
            torch.arange(num_frames // 4 + 1, device=device),
            torch.arange(height // 8 // 2, device=device),  # patch size 2
            torch.arange(height // 8 // 2, device=device),
            torch.arange(width // 8 // 2, device=device),
        ]

@@ -467,31 +416,43 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
            else None
        )

        # 7. Prepare sparse attention params if needed
        sparse_params = None  # Can be extended based on Kandinsky attention config

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, (timestep, timestep_diff) in enumerate(zip(timesteps[:-1], torch.diff(timesteps))):
                # Expand timestep to match batch size
                time = timestep.unsqueeze(0)
            for i, t in enumerate(timesteps):
                timestep = t.unsqueeze(0)

                pred_velocity = self.get_velocity(
                    latents,
                    time,
                    text_embeds,
                    negative_text_embeds,
                    visual_rope_pos,
                    text_rope_pos,
                    negative_text_rope_pos,
                    guidance_scale,
                    sparse_params,
                )
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    # print(latents.shape)
                    pred_velocity = self.transformer(
                        latents,
                        text_embeds["text_embeds"],
                        text_embeds["pooled_embed"],
                        timestep,
                        visual_rope_pos,
                        text_rope_pos,
                        scale_factor=(1, 2, 2),
                        sparse_params=None,
                        return_dict=False
                    )[0]

                    if guidance_scale > 1.0 and negative_text_embeds is not None:
                        uncond_pred_velocity = self.transformer(
                            latents,
                            negative_text_embeds["text_embeds"],
                            negative_text_embeds["pooled_embed"],
                            timestep,
                            visual_rope_pos,
                            negative_text_rope_pos,
                            scale_factor=(1, 2, 2),
                            sparse_params=None,
                            return_dict=False
                        )[0]

                # Update latents using flow matching
                latents[:, :, :, :, :16] = latents[:, :, :, :, :16] + timestep_diff * pred_velocity
                        pred_velocity = uncond_pred_velocity + guidance_scale * (
                            pred_velocity - uncond_pred_velocity
                        )

                latents = self.scheduler.step(pred_velocity, t, latents[:, :, :, :, :16], return_dict=False)[0]
                latents = torch.cat([latents, visual_cond], dim=-1)

                if callback_on_step_end is not None:
                    callback_kwargs = {}

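Note (not part of the diff): the removed update latents[..., :16] += timestep_diff * pred_velocity is a hand-rolled Euler step on the flow-matching ODE; the new code hands that same step to self.scheduler.step. A minimal sketch of the Euler step, assuming a scheduler whose step size is simply sigma_next - sigma:

import torch


def euler_flow_step(latents: torch.Tensor, velocity: torch.Tensor,
                    sigma: float, sigma_next: float) -> torch.Tensor:
    # x_next = x + (sigma_next - sigma) * v_theta(x, sigma)
    return latents + (sigma_next - sigma) * velocity


# Toy check of the update direction and magnitude:
x = torch.zeros(1, 4)
v = torch.ones(1, 4)
print(euler_flow_step(x, v, sigma=1.0, sigma_next=0.75))  # all entries -0.25
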
@@ -499,8 +460,8 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs)
                    latents = callback_outputs.pop("latents", latents)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % 1 == 0):

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        latents = latents[:, :, :, :, :16]

@@ -524,7 +485,6 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
        video = video / self.vae.config.scaling_factor
        video = self.vae.decode(video).sample
        video = ((video.clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8)

        # Convert to output format
        if output_type == "pil":
            if num_frames == 1:

@@ -533,6 +493,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
            else:
                # Video frames
                video = [video[i] for i in range(video.shape[0])]

        else:
            video = latents