mirror of https://github.com/huggingface/diffusers.git
@@ -266,6 +266,20 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    # Notes: This is for SD3.5 8b controlnet, which shares the pos_embed with the transformer
    # we should have handled this in conversion script
    def _get_pos_embed_from_transformer(self, transformer):
        pos_embed = PatchEmbed(
            height=transformer.config.sample_size,
            width=transformer.config.sample_size,
            patch_size=transformer.config.patch_size,
            in_channels=transformer.config.in_channels,
            embed_dim=transformer.inner_dim,
            pos_embed_max_size=transformer.config.pos_embed_max_size,
        )
        pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=True)
        return pos_embed

    @classmethod
    def from_transformer(
        cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True
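For orientation, a rough usage sketch of the helper added above; the checkpoint IDs and the float16 dtype are assumptions for illustration, not taken from this diff.

import torch
from diffusers import SD3ControlNetModel, SD3Transformer2DModel

# Load the base SD3.5 transformer and an 8b controlnet checkpoint (IDs are illustrative).
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", subfolder="transformer", torch_dtype=torch.float16
)
controlnet = SD3ControlNetModel.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large-controlnet-canny", torch_dtype=torch.float16
)

# Mirrors the helper above: clone the transformer's PatchEmbed so a controlnet saved
# without its own pos_embed can patchify latents exactly like the transformer does.
if getattr(controlnet.config, "use_pos_embed", True) is False:
    controlnet.pos_embed = controlnet._get_pos_embed_from_transformer(transformer)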
@@ -194,6 +194,19 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
        super().__init__()
        if isinstance(controlnet, (list, tuple)):
            controlnet = SD3MultiControlNetModel(controlnet)
        if isinstance(controlnet, SD3MultiControlNetModel):
            for controlnet_model in controlnet.nets:
                # for SD3.5 8b controlnet, it shares the pos_embed with the transformer
                if (
                    hasattr(controlnet_model.config, "use_pos_embed")
                    and controlnet_model.config.use_pos_embed is False
                ):
                    pos_embed = controlnet_model._get_pos_embed_from_transformer(transformer)
                    controlnet_model.pos_embed = pos_embed.to(controlnet_model.dtype).to(controlnet_model.device)
        elif isinstance(controlnet, SD3ControlNetModel):
            if hasattr(controlnet.config, "use_pos_embed") and controlnet.config.use_pos_embed is False:
                pos_embed = controlnet._get_pos_embed_from_transformer(transformer)
                controlnet.pos_embed = pos_embed.to(controlnet.dtype).to(controlnet.device)

        self.register_modules(
            vae=vae,
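As a reference, a minimal standalone sketch of the branching the constructor performs in this hunk; the helper name _maybe_share_pos_embed is invented here for illustration, the pipeline does this inline in __init__.

from diffusers import SD3MultiControlNetModel

def _maybe_share_pos_embed(controlnet, transformer):
    # Wrap a plain list/tuple of controlnets the same way the pipeline does.
    if isinstance(controlnet, (list, tuple)):
        controlnet = SD3MultiControlNetModel(controlnet)
    nets = controlnet.nets if isinstance(controlnet, SD3MultiControlNetModel) else [controlnet]
    for net in nets:
        # Only controlnets shipped with use_pos_embed=False (the SD3.5 8b one) need the copy.
        if getattr(net.config, "use_pos_embed", True) is False:
            pos_embed = net._get_pos_embed_from_transformer(transformer)
            net.pos_embed = pos_embed.to(net.dtype).to(net.device)
    return controlnet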
@@ -1042,15 +1055,9 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
                        controlnet_cond_scale = controlnet_cond_scale[0]
                    cond_scale = controlnet_cond_scale * controlnet_keep[i]

                if controlnet_config.use_pos_embed is False:
                    # sd35 (official) 8b controlnet
                    controlnet_model_input = self.transformer.pos_embed(latent_model_input)
                else:
                    controlnet_model_input = latent_model_input

                # controlnet(s) inference
                control_block_samples = self.controlnet(
                    hidden_states=controlnet_model_input,
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=controlnet_encoder_hidden_states,
                    pooled_projections=controlnet_pooled_projections,
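To make the use_pos_embed branch above easier to follow, here is a hedged sketch of the per-step input selection; the function name is invented for illustration, as the pipeline performs this check inline in its denoising loop rather than through a helper.

def select_controlnet_input(transformer, controlnet_config, latent_model_input):
    # The official SD3.5 8b controlnet is configured with use_pos_embed=False and has no
    # PatchEmbed of its own, so the latents are patch-embedded with the transformer's
    # pos_embed before the controlnet sees them; other SD3 controlnets take the raw latents.
    if getattr(controlnet_config, "use_pos_embed", True) is False:
        return transformer.pos_embed(latent_model_input)
    return latent_model_input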