mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Several fixes to Flux ControlNet pipelines (#9472)
* fix flux controlnet pipelines

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
This commit is contained in:
@@ -29,7 +29,14 @@ from .controlnet import (
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
)
|
||||
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
|
||||
-from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline
|
||||
+from .flux import (
|
||||
+FluxControlNetImg2ImgPipeline,
|
||||
+FluxControlNetInpaintPipeline,
|
||||
+FluxControlNetPipeline,
|
||||
+FluxImg2ImgPipeline,
|
||||
+FluxInpaintPipeline,
|
||||
+FluxPipeline,
|
||||
+)
|
||||
from .hunyuandit import HunyuanDiTPipeline
|
||||
from .kandinsky import (
|
||||
KandinskyCombinedPipeline,
|
||||
@@ -128,6 +135,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
|
||||
("lcm", LatentConsistencyModelImg2ImgPipeline),
|
||||
("flux", FluxImg2ImgPipeline),
|
||||
+("flux-controlnet", FluxControlNetImg2ImgPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -143,6 +151,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
|
||||
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
|
||||
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
|
||||
("flux", FluxInpaintPipeline),
|
||||
+("flux-controlnet", FluxControlNetInpaintPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -729,7 +729,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
-dtype=dtype,
|
||||
+dtype=self.vae.dtype,
|
||||
)
|
||||
height, width = control_image.shape[-2:]
|
||||
|
||||
@@ -763,7 +763,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
-dtype=dtype,
|
||||
+dtype=self.vae.dtype,
|
||||
)
|
||||
height, width = control_image_.shape[-2:]
|
||||
|
||||
@@ -840,12 +840,10 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
-# handle guidance
|
||||
-if self.transformer.config.guidance_embeds:
|
||||
-guidance = torch.tensor([guidance_scale], device=device)
|
||||
-guidance = guidance.expand(latents.shape[0])
|
||||
-else:
|
||||
-guidance = None
|
||||
+guidance = (
|
||||
+torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
|
||||
+)
|
||||
+guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
|
||||
|
||||
# controlnet
|
||||
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
|
||||
@@ -863,6 +861,11 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
+guidance = (
|
||||
+torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None
|
||||
+)
|
||||
+guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / 1000,
|
||||
|
||||
@@ -767,7 +767,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
-dtype=dtype,
|
||||
+dtype=self.vae.dtype,
|
||||
)
|
||||
height, width = control_image.shape[-2:]
|
||||
|
||||
@@ -798,7 +798,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
-dtype=dtype,
|
||||
+dtype=self.vae.dtype,
|
||||
)
|
||||
height, width = control_image_.shape[-2:]
|
||||
|
||||
|
||||
@@ -899,7 +899,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
-dtype=dtype,
|
||||
+dtype=self.vae.dtype,
|
||||
)
|
||||
height, width = control_image.shape[-2:]
|
||||
|
||||
@@ -933,7 +933,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
-dtype=dtype,
|
||||
+dtype=self.vae.dtype,
|
||||
)
|
||||
height, width = control_image_.shape[-2:]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user