mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
@@ -453,14 +453,14 @@ class WanMidBlock(nn.Module):
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
# First residual block
|
||||
x = self.resnets[0](x, feat_cache, feat_idx)
|
||||
x = self.resnets[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
|
||||
# Process through attention and residual blocks
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if attn is not None:
|
||||
x = attn(x)
|
||||
|
||||
x = resnet(x, feat_cache, feat_idx)
|
||||
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
|
||||
return x
|
||||
|
||||
@@ -494,9 +494,9 @@ class WanResidualDownBlock(nn.Module):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
x_copy = x.clone()
|
||||
for resnet in self.resnets:
|
||||
x = resnet(x, feat_cache, feat_idx)
|
||||
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
if self.downsampler is not None:
|
||||
x = self.downsampler(x, feat_cache, feat_idx)
|
||||
x = self.downsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
|
||||
return x + self.avg_shortcut(x_copy)
|
||||
|
||||
@@ -598,12 +598,12 @@ class WanEncoder3d(nn.Module):
|
||||
## downsamples
|
||||
for layer in self.down_blocks:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
x = self.mid_block(x, feat_cache, feat_idx)
|
||||
x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
|
||||
## head
|
||||
x = self.norm_out(x)
|
||||
@@ -694,13 +694,13 @@ class WanResidualUpBlock(nn.Module):
|
||||
|
||||
for resnet in self.resnets:
|
||||
if feat_cache is not None:
|
||||
x = resnet(x, feat_cache, feat_idx)
|
||||
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
else:
|
||||
x = resnet(x)
|
||||
|
||||
if self.upsampler is not None:
|
||||
if feat_cache is not None:
|
||||
x = self.upsampler(x, feat_cache, feat_idx)
|
||||
x = self.upsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
else:
|
||||
x = self.upsampler(x)
|
||||
|
||||
@@ -767,13 +767,13 @@ class WanUpBlock(nn.Module):
|
||||
"""
|
||||
for resnet in self.resnets:
|
||||
if feat_cache is not None:
|
||||
x = resnet(x, feat_cache, feat_idx)
|
||||
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
else:
|
||||
x = resnet(x)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
if feat_cache is not None:
|
||||
x = self.upsamplers[0](x, feat_cache, feat_idx)
|
||||
x = self.upsamplers[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
else:
|
||||
x = self.upsamplers[0](x)
|
||||
return x
|
||||
@@ -885,11 +885,11 @@ class WanDecoder3d(nn.Module):
|
||||
x = self.conv_in(x)
|
||||
|
||||
## middle
|
||||
x = self.mid_block(x, feat_cache, feat_idx)
|
||||
x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
|
||||
## upsamples
|
||||
for up_block in self.up_blocks:
|
||||
x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
|
||||
x = up_block(x, feat_cache=feat_cache, feat_idx=feat_idx, first_chunk=first_chunk)
|
||||
|
||||
## head
|
||||
x = self.norm_out(x)
|
||||
@@ -961,6 +961,9 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = False
|
||||
# keys toignore when AlignDeviceHook moves inputs/outputs between devices
|
||||
# these are shared mutable state modified in-place
|
||||
_skip_keys = ["feat_cache", "feat_idx"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -251,6 +251,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
_repeated_blocks = []
|
||||
_parallel_config = None
|
||||
_cp_plan = None
|
||||
_skip_keys = None
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@@ -866,6 +866,9 @@ def load_sub_model(
|
||||
# remove hooks
|
||||
remove_hook_from_module(loaded_sub_model, recurse=True)
|
||||
needs_offloading_to_cpu = device_map[""] == "cpu"
|
||||
skip_keys = None
|
||||
if hasattr(loaded_sub_model, "_skip_keys") and loaded_sub_model._skip_keys is not None:
|
||||
skip_keys = loaded_sub_model._skip_keys
|
||||
|
||||
if needs_offloading_to_cpu:
|
||||
dispatch_model(
|
||||
@@ -874,9 +877,10 @@ def load_sub_model(
|
||||
device_map=device_map,
|
||||
force_hooks=True,
|
||||
main_device=0,
|
||||
skip_keys=skip_keys,
|
||||
)
|
||||
else:
|
||||
dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
|
||||
dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True, skip_keys=skip_keys)
|
||||
|
||||
return loaded_sub_model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user