mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix: Add _skip_keys for AutoencoderKLWan (#12523)

add
Author: YiYi Xu
Date: 2025-10-22 07:53:13 -10:00
Committed by: GitHub
Parent: a0a51eb098
Commit: bec2d8eaea
3 changed files with 21 additions and 13 deletions
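
Context for the change: the Wan VAE blocks thread a feature cache through every sub-block as two call arguments, `feat_cache` (a list of cached feature tensors) and `feat_idx` (a one-element list used as a mutable cursor). Both are shared objects that the blocks update in place, so the caller must keep receiving the very same list objects back. When a pipeline is dispatched with a device map, accelerate's device-alignment hooks normally move all call arguments to the module's execution device, which can rebuild these containers and silently discard the shared state. Below is a minimal, self-contained sketch of the pattern; the `toy_block` helper is hypothetical, not the actual Wan code.

```python
import torch


def toy_block(x, feat_cache=None, feat_idx=[0]):
    # Mimics the Wan block signature: write into the cache slot at the current
    # cursor position, then advance the shared cursor in place.
    if feat_cache is not None:
        idx = feat_idx[0]
        feat_cache[idx] = x.detach()  # in-place write into the shared list
        feat_idx[0] += 1              # in-place cursor bump, visible to the caller
    return x * 2.0


x = torch.ones(1, 4)
cache, cursor = [None, None], [0]

x = toy_block(x, feat_cache=cache, feat_idx=cursor)
x = toy_block(x, feat_cache=cache, feat_idx=cursor)

print(cursor)          # [2] -> the caller observes the in-place updates
print(cache[1].shape)  # torch.Size([1, 4])

# If a hook rebuilt `cache`/`cursor` while moving kwargs between devices, the
# blocks would mutate the copies and the caller's objects would stay stale.
```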


@@ -453,14 +453,14 @@ class WanMidBlock(nn.Module):
     def forward(self, x, feat_cache=None, feat_idx=[0]):
         # First residual block
-        x = self.resnets[0](x, feat_cache, feat_idx)
+        x = self.resnets[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
         # Process through attention and residual blocks
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if attn is not None:
                 x = attn(x)
-            x = resnet(x, feat_cache, feat_idx)
+            x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
         return x
@@ -494,9 +494,9 @@ class WanResidualDownBlock(nn.Module):
     def forward(self, x, feat_cache=None, feat_idx=[0]):
         x_copy = x.clone()
         for resnet in self.resnets:
-            x = resnet(x, feat_cache, feat_idx)
+            x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
         if self.downsampler is not None:
-            x = self.downsampler(x, feat_cache, feat_idx)
+            x = self.downsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
         return x + self.avg_shortcut(x_copy)
@@ -598,12 +598,12 @@ class WanEncoder3d(nn.Module):
         ## downsamples
         for layer in self.down_blocks:
             if feat_cache is not None:
-                x = layer(x, feat_cache, feat_idx)
+                x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = layer(x)
         ## middle
-        x = self.mid_block(x, feat_cache, feat_idx)
+        x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
         ## head
         x = self.norm_out(x)
@@ -694,13 +694,13 @@ class WanResidualUpBlock(nn.Module):
         for resnet in self.resnets:
             if feat_cache is not None:
-                x = resnet(x, feat_cache, feat_idx)
+                x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = resnet(x)
         if self.upsampler is not None:
             if feat_cache is not None:
-                x = self.upsampler(x, feat_cache, feat_idx)
+                x = self.upsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = self.upsampler(x)
@@ -767,13 +767,13 @@ class WanUpBlock(nn.Module):
         """
         for resnet in self.resnets:
             if feat_cache is not None:
-                x = resnet(x, feat_cache, feat_idx)
+                x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = resnet(x)
         if self.upsamplers is not None:
             if feat_cache is not None:
-                x = self.upsamplers[0](x, feat_cache, feat_idx)
+                x = self.upsamplers[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = self.upsamplers[0](x)
         return x
@@ -885,11 +885,11 @@ class WanDecoder3d(nn.Module):
             x = self.conv_in(x)
         ## middle
-        x = self.mid_block(x, feat_cache, feat_idx)
+        x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
         ## upsamples
         for up_block in self.up_blocks:
-            x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
+            x = up_block(x, feat_cache=feat_cache, feat_idx=feat_idx, first_chunk=first_chunk)
         ## head
         x = self.norm_out(x)
@@ -961,6 +961,9 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
     """
     _supports_gradient_checkpointing = False
+    # keys to ignore when AlignDevicesHook moves inputs/outputs between devices
+    # these are shared mutable state modified in-place
+    _skip_keys = ["feat_cache", "feat_idx"]
     @register_to_config
     def __init__(
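
The hunks above are all in the Wan VAE blocks: the cache arguments are now always passed as keyword arguments (presumably so a device-alignment hook can recognize them by name), and the model declares `_skip_keys = ["feat_cache", "feat_idx"]`. The snippet below is a simplified stand-in for what such a hook does with `skip_keys`; it is not accelerate's actual `send_to_device`, just an illustration of why skipping matters for container arguments.

```python
import torch


def move_to_device(value, device):
    # Recursive move: nested tensors are copied to `device` and containers are
    # rebuilt, so the caller no longer holds the same list objects.
    if torch.is_tensor(value):
        return value.to(device)
    if isinstance(value, list):
        return [move_to_device(v, device) for v in value]
    return value


def align_kwargs(kwargs, device, skip_keys=()):
    # Simplified stand-in for a device-alignment hook: every argument is moved
    # to the execution device, except the names in `skip_keys`, which are
    # passed through as the original objects.
    return {
        name: value if name in skip_keys else move_to_device(value, device)
        for name, value in kwargs.items()
    }


kwargs = {"x": torch.ones(2, 2), "feat_cache": [torch.zeros(1)], "feat_idx": [0]}

rebuilt = align_kwargs(kwargs, "cpu")
assert rebuilt["feat_cache"] is not kwargs["feat_cache"]  # fresh list: in-place writes are lost

skipped = align_kwargs(kwargs, "cpu", skip_keys=("feat_cache", "feat_idx"))
assert skipped["feat_cache"] is kwargs["feat_cache"]  # same list: in-place writes remain visible
```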


@@ -251,6 +251,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _repeated_blocks = []
     _parallel_config = None
     _cp_plan = None
+    _skip_keys = None
     def __init__(self):
         super().__init__()
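
With the `_skip_keys = None` default on `ModelMixin`, any model class can opt specific call arguments out of device alignment just by overriding the attribute. A hypothetical minimal subclass, together with the loader-side lookup that mirrors the `load_sub_model` change further down:

```python
import torch
from diffusers import ModelMixin


class MyCachedModel(ModelMixin):
    # Hypothetical model: opt these call arguments out of device alignment,
    # since downstream blocks update them in place.
    _skip_keys = ["feat_cache", "feat_idx"]

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        return self.proj(x)


# Loader-side lookup, mirroring the check added to load_sub_model below.
model = MyCachedModel()
skip_keys = None
if hasattr(model, "_skip_keys") and model._skip_keys is not None:
    skip_keys = model._skip_keys
print(skip_keys)  # ['feat_cache', 'feat_idx']
```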


@@ -866,6 +866,9 @@ def load_sub_model(
         # remove hooks
         remove_hook_from_module(loaded_sub_model, recurse=True)
         needs_offloading_to_cpu = device_map[""] == "cpu"
+        skip_keys = None
+        if hasattr(loaded_sub_model, "_skip_keys") and loaded_sub_model._skip_keys is not None:
+            skip_keys = loaded_sub_model._skip_keys
         if needs_offloading_to_cpu:
             dispatch_model(
@@ -874,9 +877,10 @@
                 device_map=device_map,
                 force_hooks=True,
                 main_device=0,
+                skip_keys=skip_keys,
             )
         else:
-            dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
+            dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True, skip_keys=skip_keys)
     return loaded_sub_model
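
The looked-up `skip_keys` value is simply forwarded to accelerate's `dispatch_model`, which installs hooks that leave those keyword arguments untouched when aligning devices. A rough sketch of doing the same thing by hand for the Wan VAE; the checkpoint id, device map, and dtype here are illustrative assumptions, not part of this commit.

```python
import torch
from accelerate import dispatch_model
from diffusers import AutoencoderKLWan

# Illustrative checkpoint; any AutoencoderKLWan weights would do.
vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)

device_map = {"": "cuda:0"}  # single-device map, for illustration
skip_keys = getattr(vae, "_skip_keys", None)

vae = dispatch_model(vae, device_map=device_map, force_hooks=True, skip_keys=skip_keys)
# The hooks installed by dispatch_model now leave `feat_cache` and `feat_idx`
# untouched when they move the remaining call arguments to cuda:0.
```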