update
@@ -422,7 +422,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         device: Optional[torch.device] = None,
     ):
         if video is not None:
-            base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+            base = self.vae_scale_factor_spatial * (
+                self.transformer.config.patch_size[1]
+                if self.transformer is not None
+                else self.transformer_2.config.patch_size[1]
+            )
             video_height, video_width = self.video_processor.get_default_height_width(video[0])
 
             if video_height * video_width > height * width:
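
The same fallback pattern (read a config value from `transformer` when it is loaded, otherwise from `transformer_2`) recurs in the hunks below. A minimal standalone sketch of the idea; the `_resolve_config_value` helper and the `SimpleNamespace` stand-ins are illustrative only, not part of the pipeline:

from types import SimpleNamespace


def _resolve_config_value(transformer, transformer_2, name):
    # Hypothetical helper: prefer the loaded transformer's config and fall back
    # to transformer_2. Exactly one of the two is expected to be non-None.
    source = transformer if transformer is not None else transformer_2
    return getattr(source.config, name)


# Toy stand-ins for the two transformer components (values are illustrative).
high_noise = SimpleNamespace(config=SimpleNamespace(patch_size=(1, 2, 2), in_channels=16))
low_noise = SimpleNamespace(config=SimpleNamespace(patch_size=(1, 2, 2), in_channels=16))

print(_resolve_config_value(high_noise, None, "patch_size")[1])  # 2, read from `transformer`
print(_resolve_config_value(None, low_noise, "in_channels"))     # 16, falls back to `transformer_2`
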
@@ -597,7 +601,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 "Generating with more than one video is not yet supported. This may be supported in the future."
             )
 
-        transformer_patch_size = self.transformer.config.patch_size[1]
+        transformer_patch_size = (
+            self.transformer.config.patch_size[1]
+            if self.transformer is not None
+            else self.transformer_2.config.patch_size[1]
+        )
 
         mask_list = []
         for mask_, reference_images_batch in zip(mask, reference_images):
@@ -852,20 +860,25 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             batch_size = prompt_embeds.shape[0]
 
         vae_dtype = self.vae.dtype
-        transformer_dtype = self.transformer.dtype
+        transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
 
+        vace_layers = (
+            self.transformer.config.vace_layers
+            if self.transformer is not None
+            else self.transformer_2.config.vace_layers
+        )
         if isinstance(conditioning_scale, (int, float)):
-            conditioning_scale = [conditioning_scale] * len(self.transformer.config.vace_layers)
+            conditioning_scale = [conditioning_scale] * len(vace_layers)
         if isinstance(conditioning_scale, list):
-            if len(conditioning_scale) != len(self.transformer.config.vace_layers):
+            if len(conditioning_scale) != len(vace_layers):
                 raise ValueError(
-                    f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(self.transformer.config.vace_layers)}."
+                    f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}."
                 )
             conditioning_scale = torch.tensor(conditioning_scale)
         if isinstance(conditioning_scale, torch.Tensor):
-            if conditioning_scale.size(0) != len(self.transformer.config.vace_layers):
+            if conditioning_scale.size(0) != len(vace_layers):
                 raise ValueError(
-                    f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(self.transformer.config.vace_layers)}."
+                    f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(vace_layers)}."
                 )
             conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype)
 
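
The `conditioning_scale` handling above always ends with one scale per VACE layer, whether the caller passes a scalar, a list, or a tensor. A minimal sketch of that normalization in isolation; the standalone function name and signature are illustrative, not the pipeline's API:

from typing import List, Union

import torch


def normalize_conditioning_scale(
    conditioning_scale: Union[int, float, List[float], torch.Tensor],
    vace_layers: List[int],
    device: str = "cpu",
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    # Scalar -> repeat once per VACE layer.
    if isinstance(conditioning_scale, (int, float)):
        conditioning_scale = [conditioning_scale] * len(vace_layers)
    # List -> validate the length, then convert to a tensor.
    if isinstance(conditioning_scale, list):
        if len(conditioning_scale) != len(vace_layers):
            raise ValueError(
                f"Length of `conditioning_scale` {len(conditioning_scale)} does not match "
                f"number of layers {len(vace_layers)}."
            )
        conditioning_scale = torch.tensor(conditioning_scale)
    # Tensor -> validate the length, then move to the transformer's device/dtype.
    if conditioning_scale.size(0) != len(vace_layers):
        raise ValueError(
            f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match "
            f"number of layers {len(vace_layers)}."
        )
    return conditioning_scale.to(device=device, dtype=dtype)


print(normalize_conditioning_scale(1.0, vace_layers=[0, 5, 10, 15]))  # tensor([1., 1., 1., 1.])
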
@@ -908,7 +921,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
         conditioning_latents = conditioning_latents.to(transformer_dtype)
 
-        num_channels_latents = self.transformer.config.in_channels
+        num_channels_latents = (
+            self.transformer.config.in_channels
+            if self.transformer is not None
+            else self.transformer_2.config.in_channels
+        )
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -976,7 +993,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
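
The hunk above only renames `guidance_scale` to `current_guidance_scale` inside the classifier-free-guidance step; how the current scale is chosen per step is not shown here. A hedged sketch of how a two-stage (high-noise / low-noise) pipeline might select it, with every name below (`boundary_timestep`, `guidance_scale_2`, `select_stage`) assumed rather than taken from this diff:

def select_stage(t, boundary_timestep, transformer, transformer_2, guidance_scale, guidance_scale_2):
    # Hedged sketch only: assumed per-step routing between the two transformers
    # and their guidance scales, not the pipeline's actual implementation.
    if boundary_timestep is None or t >= boundary_timestep:
        # High-noise stage: use the first transformer and its guidance scale.
        return transformer, guidance_scale
    # Low-noise stage: switch to transformer_2 and its own guidance scale.
    return transformer_2, guidance_scale_2


# Example: with a boundary timestep of 900, a step at t=750 falls in the low-noise stage.
current_model, current_guidance_scale = select_stage(
    t=750, boundary_timestep=900, transformer="high_noise_model",
    transformer_2="low_noise_model", guidance_scale=5.0, guidance_scale_2=3.0,
)
print(current_model, current_guidance_scale)  # low_noise_model 3.0
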
@@ -19,9 +19,15 @@ import torch
 from PIL import Image
 from transformers import AutoTokenizer, T5EncoderModel
 
-from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
+from diffusers import (
+    AutoencoderKLWan,
+    FlowMatchEulerDiscreteScheduler,
+    UniPCMultistepScheduler,
+    WanVACEPipeline,
+    WanVACETransformer3DModel,
+)
 
-from ...testing_utils import enable_full_determinism
+from ...testing_utils import enable_full_determinism, torch_device
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin
 
@@ -212,3 +218,35 @@ class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     )
     def test_save_load_float16(self):
         pass
+
+    def test_inference_with_only_transformer(self):
+        components = self.get_dummy_components()
+        components["transformer_2"] = None
+        components["boundary_ratio"] = 0.0
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        video = pipe(**inputs).frames[0]
+        assert video.shape == (17, 3, 16, 16)
+
+    def test_inference_with_only_transformer_2(self):
+        components = self.get_dummy_components()
+        components["transformer_2"] = components["transformer"]
+        components["transformer"] = None
+
+        # FlowMatchEulerDiscreteScheduler doesn't support running low noise only scheduler
+        # because starting timestep t == 1000 == boundary_timestep
+        components["scheduler"] = UniPCMultistepScheduler(
+            prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
+        )
+
+        components["boundary_ratio"] = 1.0
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        video = pipe(**inputs).frames[0]
+        assert video.shape == (17, 3, 16, 16)
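
The two new tests pin down the single-transformer edge cases: with `boundary_ratio` set to 0.0 every denoising step stays in the high-noise stage (only `transformer` is used), and with 1.0 every step should fall in the low-noise stage (only `transformer_2`). A hedged sketch of the boundary arithmetic this implies, assuming the boundary timestep is `boundary_ratio * num_train_timesteps` as in other Wan 2.2 pipelines:

def uses_low_noise_stage(t: float, boundary_ratio: float, num_train_timesteps: int = 1000) -> bool:
    # A step belongs to the low-noise stage (transformer_2) once its timestep drops
    # strictly below the boundary timestep. Assumed formula, not read from this diff.
    boundary_timestep = boundary_ratio * num_train_timesteps
    return t < boundary_timestep


print(uses_low_noise_stage(t=999, boundary_ratio=0.0))  # False -> only `transformer` runs
print(uses_low_noise_stage(t=999, boundary_ratio=1.0))  # True  -> only `transformer_2` runs

Under that reading, `uses_low_noise_stage(t=1000, boundary_ratio=1.0)` is False, so a scheduler that starts exactly at t == 1000 would still route the first step to the missing high-noise transformer; that matches the comment in `test_inference_with_only_transformer_2` and explains the switch to `UniPCMultistepScheduler` there.
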