1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix offloading for sd3.5 controlnets (#10072)

* add
This commit is contained in:
YiYi Xu
2024-12-02 10:11:25 -10:00
committed by GitHub
parent c44fba8899
commit cd344393e2
2 changed files with 28 additions and 7 deletions

View File

@@ -266,6 +266,20 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
# Notes: This is for SD3.5 8b controlnet, which shares the pos_embed with the transformer
# we should have handled this in conversion script
def _get_pos_embed_from_transformer(self, transformer):
    """Build a fresh ``PatchEmbed`` mirroring *transformer*'s and copy its weights.

    Used for the SD3.5 8b controlnet, which shares the positional embedding
    with the transformer (see the note in the original conversion script).

    Args:
        transformer: the SD3 transformer whose ``pos_embed`` module and config
            (sample_size, patch_size, in_channels, pos_embed_max_size) are cloned.

    Returns:
        A new ``PatchEmbed`` holding an exact copy of ``transformer.pos_embed``'s
        parameters.
    """
    cfg = transformer.config
    embed = PatchEmbed(
        height=cfg.sample_size,
        width=cfg.sample_size,
        patch_size=cfg.patch_size,
        in_channels=cfg.in_channels,
        embed_dim=transformer.inner_dim,
        pos_embed_max_size=cfg.pos_embed_max_size,
    )
    # strict=True guarantees the cloned module matches the source exactly.
    embed.load_state_dict(transformer.pos_embed.state_dict(), strict=True)
    return embed
@classmethod
def from_transformer(
cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True

View File

@@ -194,6 +194,19 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
super().__init__()
if isinstance(controlnet, (list, tuple)):
controlnet = SD3MultiControlNetModel(controlnet)
if isinstance(controlnet, SD3MultiControlNetModel):
for controlnet_model in controlnet.nets:
# for SD3.5 8b controlnet, it shares the pos_embed with the transformer
if (
hasattr(controlnet_model.config, "use_pos_embed")
and controlnet_model.config.use_pos_embed is False
):
pos_embed = controlnet_model._get_pos_embed_from_transformer(transformer)
controlnet_model.pos_embed = pos_embed.to(controlnet_model.dtype).to(controlnet_model.device)
elif isinstance(controlnet, SD3ControlNetModel):
if hasattr(controlnet.config, "use_pos_embed") and controlnet.config.use_pos_embed is False:
pos_embed = controlnet._get_pos_embed_from_transformer(transformer)
controlnet.pos_embed = pos_embed.to(controlnet.dtype).to(controlnet.device)
self.register_modules(
vae=vae,
@@ -1042,15 +1055,9 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]
if controlnet_config.use_pos_embed is False:
# sd35 (official) 8b controlnet
controlnet_model_input = self.transformer.pos_embed(latent_model_input)
else:
controlnet_model_input = latent_model_input
# controlnet(s) inference
control_block_samples = self.controlnet(
hidden_states=controlnet_model_input,
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=controlnet_encoder_hidden_states,
pooled_projections=controlnet_pooled_projections,