Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
flux controlnet fix (control_modes batch & others) (#9507)
* flux controlnet mode to take into account batch size
* incorporate yiyixuxu's suggestions (cleaner logic) as well as clean up control mode handling for the multi case
* fix
* fix use_guidance when controlnet is a multi and does not have config

Co-authored-by: Christopher Beckham <christopher.j.beckham@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
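Before the diff itself, a hedged usage sketch of the multi-ControlNet call pattern this fix targets. The ControlNet repository IDs and control images below are placeholders, and the exact imports and keyword arguments should be checked against the installed diffusers version; what the patch itself establishes is only that control_mode may be given as a per-ControlNet list (None entries allowed) whose length must match the number of control images.

# Hedged usage sketch; repo IDs are placeholders, images are blank stand-ins.
import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline, FluxMultiControlNetModel
from PIL import Image

canny_image = Image.new("RGB", (1024, 1024))   # placeholder conditioning images
depth_image = Image.new("RGB", (1024, 1024))

controlnet = FluxMultiControlNetModel(
    [
        FluxControlNetModel.from_pretrained("your-org/flux-controlnet-canny", torch_dtype=torch.bfloat16),
        FluxControlNetModel.from_pretrained("your-org/flux-controlnet-depth", torch_dtype=torch.bfloat16),
    ]
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    prompt="a photo of a cat",
    control_image=[canny_image, depth_image],
    control_mode=[0, None],                    # one entry per ControlNet; None = no union mode
    controlnet_conditioning_scale=[0.6, 0.4],
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]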
@@ -502,16 +502,17 @@ class FluxMultiControlNetModel(ModelMixin):
                     control_block_samples = block_samples
                     control_single_block_samples = single_block_samples
                 else:
-                    control_block_samples = [
-                        control_block_sample + block_sample
-                        for control_block_sample, block_sample in zip(control_block_samples, block_samples)
-                    ]
-
-                    control_single_block_samples = [
-                        control_single_block_sample + block_sample
-                        for control_single_block_sample, block_sample in zip(
-                            control_single_block_samples, single_block_samples
-                        )
-                    ]
+                    if block_samples is not None and control_block_samples is not None:
+                        control_block_samples = [
+                            control_block_sample + block_sample
+                            for control_block_sample, block_sample in zip(control_block_samples, block_samples)
+                        ]
+                    if single_block_samples is not None and control_single_block_samples is not None:
+                        control_single_block_samples = [
+                            control_single_block_sample + block_sample
+                            for control_single_block_sample, block_sample in zip(
+                                control_single_block_samples, single_block_samples
+                            )
+                        ]
 
         return control_block_samples, control_single_block_samples
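Read in isolation, the merge guard added above amounts to the following minimal sketch. merge_block_samples is a hypothetical helper written only for illustration (it is not part of diffusers), and the toy tensors are made up.

# Minimal sketch of the None-tolerant merge; not the diffusers implementation.
from typing import List, Optional

import torch


def merge_block_samples(
    accumulated: Optional[List[torch.Tensor]],
    new: Optional[List[torch.Tensor]],
) -> Optional[List[torch.Tensor]]:
    # If either side is missing (e.g. a net returned None for its single-block
    # samples), keep whichever list exists instead of trying to sum into None.
    if accumulated is None or new is None:
        return accumulated if new is None else new
    return [acc + cur for acc, cur in zip(accumulated, new)]


# Toy check: the second "net" contributes no samples, so the first result is kept.
first = [torch.ones(2, 4), torch.ones(2, 4)]
second = None
assert merge_block_samples(first, second) is first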
@@ -747,10 +747,12 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
                 width_control_image,
             )
 
             # set control mode
             if control_mode is not None:
+                if not isinstance(control_mode, int):
+                    raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`")
                 control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
-                control_mode = control_mode.reshape([-1, 1])
+                control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
 
         elif isinstance(self.controlnet, FluxMultiControlNetModel):
             control_images = []
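A small illustration of what the new view/expand line does, with a made-up batch size rather than pipeline state: a single integer control mode is broadcast into a per-sample column so it lines up with the packed control image batch.

# Illustration only; batch_size stands in for control_image.shape[0].
import torch

batch_size = 4
control_mode = 2        # e.g. one mode of a union ControlNet

mode = torch.tensor(control_mode, dtype=torch.long)
mode = mode.view(-1, 1).expand(batch_size, 1)
print(mode.shape)       # torch.Size([4, 1]) -- one mode entry per sample
print(mode)             # tensor([[2], [2], [2], [2]])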
@@ -785,16 +787,22 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
             control_image = control_images
 
+            # Here we ensure that `control_mode` has the same length as the control_image.
+            if isinstance(control_mode, list) and len(control_mode) != len(control_image):
+                raise ValueError(
+                    "For Multi-ControlNet, `control_mode` must be a list of the same "
+                    + " length as the number of controlnets (control images) specified"
+                )
+            if not isinstance(control_mode, list):
+                control_mode = [control_mode] * len(control_image)
             # set control mode
-            control_modes = []
-            for cmode in control_mode:
-                if cmode is None:
-                    cmode = -1
-                control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
-                control_modes.append(control_mode)
-            control_mode = control_modes
+            control_mode_ = []
+            if isinstance(control_mode, list):
+                for cmode in control_mode:
+                    if cmode is None:
+                        control_mode_.append(-1)
+                    else:
+                        control_mode_.append(cmode)
+            control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
+            control_mode = control_mode.reshape([-1, 1])
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
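The multi-ControlNet branch above boils down to this normalization step; the values here are hypothetical and stand in for one mode entry per ControlNet.

# Illustration of the list normalization: None entries become -1, and the result
# is a (num_controlnets, 1) long tensor.
import torch

control_mode = [0, None, 3]
control_mode_ = [-1 if m is None else m for m in control_mode]

modes = torch.tensor(control_mode_, dtype=torch.long).reshape([-1, 1])
print(modes)  # tensor([[ 0], [-1], [ 3]])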
@@ -840,9 +848,12 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
-                guidance = (
-                    torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
-                )
+                if isinstance(self.controlnet, FluxMultiControlNetModel):
+                    use_guidance = self.controlnet.nets[0].config.guidance_embeds
+                else:
+                    use_guidance = self.controlnet.config.guidance_embeds
+
+                guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
                 guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
 
                 # controlnet
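The use_guidance selection above could be factored as the following sketch; resolve_use_guidance is a hypothetical helper written only for illustration, not diffusers API.

# Sketch of the guidance-embeds lookup as a standalone helper.
def resolve_use_guidance(controlnet) -> bool:
    # Per the commit message, a multi-ControlNet wrapper has no config of its own;
    # it exposes its sub-nets via `nets`, so the first sub-net's config is consulted.
    nets = getattr(controlnet, "nets", None)
    if nets is not None:
        return bool(nets[0].config.guidance_embeds)
    return bool(controlnet.config.guidance_embeds)


# guidance = torch.tensor([guidance_scale], device=device) if resolve_use_guidance(controlnet) else None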