1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Several fixes to Flux ControlNet pipelines (#9472)

* fix flux controlnet pipelines

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
This commit is contained in:
Vladimir Mandic
2024-09-19 21:49:36 -04:00
committed by GitHub
parent 2b443a5d62
commit 14a1b86fc7
4 changed files with 25 additions and 13 deletions

View File

@@ -29,7 +29,14 @@ from .controlnet import (
StableDiffusionXLControlNetPipeline,
)
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline
from .flux import (
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxPipeline,
)
from .hunyuandit import HunyuanDiTPipeline
from .kandinsky import (
KandinskyCombinedPipeline,
@@ -128,6 +135,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
("lcm", LatentConsistencyModelImg2ImgPipeline),
("flux", FluxImg2ImgPipeline),
("flux-controlnet", FluxControlNetImg2ImgPipeline),
]
)
@@ -143,6 +151,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
("flux", FluxInpaintPipeline),
("flux-controlnet", FluxControlNetInpaintPipeline),
]
)

View File

@@ -729,7 +729,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
dtype=self.vae.dtype,
)
height, width = control_image.shape[-2:]
@@ -763,7 +763,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
dtype=self.vae.dtype,
)
height, width = control_image_.shape[-2:]
@@ -840,12 +840,10 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
# handle guidance
if self.transformer.config.guidance_embeds:
guidance = torch.tensor([guidance_scale], device=device)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
guidance = (
torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
)
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
# controlnet
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
@@ -863,6 +861,11 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
return_dict=False,
)
guidance = (
torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None
)
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,

View File

@@ -767,7 +767,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
dtype=self.vae.dtype,
)
height, width = control_image.shape[-2:]
@@ -798,7 +798,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
dtype=self.vae.dtype,
)
height, width = control_image_.shape[-2:]

View File

@@ -899,7 +899,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
dtype=self.vae.dtype,
)
height, width = control_image.shape[-2:]
@@ -933,7 +933,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
dtype=self.vae.dtype,
)
height, width = control_image_.shape[-2:]