1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

flux controlnet fix (control_modes batch & others) (#9507)

* flux controlnet mode to take into account batch size

* incorporate yiyixuxu's suggestions (cleaner logic) as well as clean up control mode handling for multi case

* fix

* fix use_guidance when controlnet is a multi and does not have config

---------

Co-authored-by: Christopher Beckham <christopher.j.beckham@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
YiYi Xu
2024-09-25 19:09:54 -10:00
committed by GitHub
parent 1c6ede9371
commit 9cd37557d5
2 changed files with 37 additions and 25 deletions

View File

@@ -502,16 +502,17 @@ class FluxMultiControlNetModel(ModelMixin):
control_block_samples = block_samples
control_single_block_samples = single_block_samples
else:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
]
control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples
)
]
if block_samples is not None and control_block_samples is not None:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
]
if single_block_samples is not None and control_single_block_samples is not None:
control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples
)
]
return control_block_samples, control_single_block_samples

View File

@@ -747,10 +747,12 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
width_control_image,
)
# set control mode
# Here we ensure that `control_mode` has the same length as the control_image.
if control_mode is not None:
if not isinstance(control_mode, int):
raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`")
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])
control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
elif isinstance(self.controlnet, FluxMultiControlNetModel):
control_images = []
@@ -785,16 +787,22 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
control_image = control_images
# Here we ensure that `control_mode` has the same length as the control_image.
if isinstance(control_mode, list) and len(control_mode) != len(control_image):
raise ValueError(
"For Multi-ControlNet, `control_mode` must be a list of the same "
+ " length as the number of controlnets (control images) specified"
)
if not isinstance(control_mode, list):
control_mode = [control_mode] * len(control_image)
# set control mode
control_mode_ = []
if isinstance(control_mode, list):
for cmode in control_mode:
if cmode is None:
control_mode_.append(-1)
else:
control_mode_.append(cmode)
control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])
control_modes = []
for cmode in control_mode:
if cmode is None:
cmode = -1
control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
control_modes.append(control_mode)
control_mode = control_modes
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
@@ -840,9 +848,12 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
guidance = (
torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
)
if isinstance(self.controlnet, FluxMultiControlNetModel):
use_guidance = self.controlnet.nets[0].config.guidance_embeds
else:
use_guidance = self.controlnet.config.guidance_embeds
guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
# controlnet