[AnimateDiff+Controlnet] Fix multicontrolnet support (#6551)
* fix multicontrolnet support
* update README with multicontrolnet example
examples/community/README.md
@@ -2989,7 +2989,7 @@ pipe = DiffusionPipeline.from_pretrained(
     custom_pipeline="pipeline_animatediff_controlnet",
 ).to(device="cuda", dtype=torch.float16)
 pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
-    model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
+    model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
 )
 pipe.enable_vae_slicing()
@@ -3005,7 +3005,7 @@ result = pipe(
     width=512,
     height=768,
     conditioning_frames=conditioning_frames,
-    num_inference_steps=12,
+    num_inference_steps=20,
 ).frames[0]

 from diffusers.utils import export_to_gif
@@ -3029,6 +3029,79 @@ export_to_gif(result.frames[0], "result.gif")
 </tr>
 </table>

+You can also use multiple controlnets at once!
+
+```python
+import imageio
+import requests
+import torch
+from io import BytesIO
+from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter
+from diffusers.pipelines import DiffusionPipeline
+from diffusers.schedulers import DPMSolverMultistepScheduler
+from PIL import Image
+
+motion_id = "guoyww/animatediff-motion-adapter-v1-5-2"
+adapter = MotionAdapter.from_pretrained(motion_id)
+controlnet1 = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16)
+controlnet2 = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
+
+model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
+pipe = DiffusionPipeline.from_pretrained(
+    model_id,
+    motion_adapter=adapter,
+    controlnet=[controlnet1, controlnet2],
+    vae=vae,
+    custom_pipeline="pipeline_animatediff_controlnet",
+).to(device="cuda", dtype=torch.float16)
+pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
+    model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
+)
+pipe.enable_vae_slicing()
+
+def load_video(file_path: str):
+    images = []
+
+    if file_path.startswith(('http://', 'https://')):
+        # If the file_path is a URL
+        response = requests.get(file_path)
+        response.raise_for_status()
+        content = BytesIO(response.content)
+        vid = imageio.get_reader(content)
+    else:
+        # Assuming it's a local file path
+        vid = imageio.get_reader(file_path)
+
+    for frame in vid:
+        pil_image = Image.fromarray(frame)
+        images.append(pil_image)
+
+    return images
+
+video = load_video("dance.gif")
+
+# You need to install it using `pip install controlnet_aux`
+from controlnet_aux.processor import Processor
+
+p1 = Processor("openpose_full")
+cn1 = [p1(frame) for frame in video]
+
+p2 = Processor("canny")
+cn2 = [p2(frame) for frame in video]
+
+prompt = "astronaut in space, dancing"
+negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"
+result = pipe(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    width=512,
+    height=768,
+    conditioning_frames=[cn1, cn2],
+    num_inference_steps=20,
+)
+
+from diffusers.utils import export_to_gif
+export_to_gif(result.frames[0], "result.gif")
+```
+
 ### DemoFusion

 This pipeline is the official implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973).
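For clarity on the `conditioning_frames` layout the new multicontrolnet README example above relies on: the outer list has one entry per controlnet, and each inner list carries one conditioning image per output frame. A minimal sketch of that contract (not part of the commit; the blank frames and the `num_frames` value are illustrative):

```python
from PIL import Image

num_frames = 16  # illustrative; must equal the length of every inner list
blank = Image.new("RGB", (512, 768))

pose_frames = [blank] * num_frames   # conditioning for the openpose controlnet
canny_frames = [blank] * num_frames  # conditioning for the canny controlnet
conditioning_frames = [pose_frames, canny_frames]  # one inner list per controlnet

# Mirrors the checks the pipeline now performs in check_inputs:
assert isinstance(conditioning_frames, list) and isinstance(conditioning_frames[0], list)
assert all(len(frames) == num_frames for frames in conditioning_frames)
```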
examples/community/pipeline_animatediff_controlnet.py
@@ -14,7 +14,7 @@

 import inspect
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -66,7 +66,7 @@ EXAMPLE_DOC_STRING = """
         ...     custom_pipeline="pipeline_animatediff_controlnet",
         ... ).to(device="cuda", dtype=torch.float16)
         >>> pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
-        ...     model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
+        ...     model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
         ... )
         >>> pipe.enable_vae_slicing()

@@ -83,7 +83,7 @@ EXAMPLE_DOC_STRING = """
         ...     height=768,
         ...     conditioning_frames=conditioning_frames,
         ...     num_inference_steps=12,
-        ... ).frames[0]
+        ... )

         >>> from diffusers.utils import export_to_gif
         >>> export_to_gif(result.frames[0], "result.gif")
@@ -151,7 +151,7 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
         motion_adapter: MotionAdapter,
-        controlnet: Union[ControlNetModel, MultiControlNetModel],
+        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
         scheduler: Union[
             DDIMScheduler,
             PNDMScheduler,
@@ -166,6 +166,9 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
         super().__init__()
         unet = UNetMotionModel.from_unet2d(unet, motion_adapter)

+        if isinstance(controlnet, (list, tuple)):
+            controlnet = MultiControlNetModel(controlnet)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
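The constructor change above normalizes plain Python sequences into a `MultiControlNetModel` before `register_modules`, so the rest of the pipeline only has to distinguish two controlnet types. A sketch of the same pattern in isolation (reusing the checkpoints from the README example; loading them here only serves to make the snippet self-contained):

```python
import torch
from diffusers import ControlNetModel
from diffusers.pipelines.controlnet import MultiControlNetModel

cn1 = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16)
cn2 = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)

controlnet = [cn1, cn2]
# The same normalization the pipeline constructor now applies:
if isinstance(controlnet, (list, tuple)):
    controlnet = MultiControlNetModel(controlnet)

print(type(controlnet).__name__)  # MultiControlNetModel
print(len(controlnet.nets))       # 2
```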
@@ -488,6 +491,7 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
         prompt,
         height,
         width,
+        num_frames,
         callback_steps,
         negative_prompt=None,
         prompt_embeds=None,
@@ -557,31 +561,21 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
             or is_compiled
             and isinstance(self.controlnet._orig_mod, ControlNetModel)
         ):
-            if isinstance(image, list):
-                for image_ in image:
-                    self.check_image(image_, prompt, prompt_embeds)
-            else:
-                self.check_image(image, prompt, prompt_embeds)
+            if not isinstance(image, list):
+                raise TypeError(f"For single controlnet, `image` must be of type `list` but got {type(image)}")
+            if len(image) != num_frames:
+                raise ValueError(f"Expected image to have length {num_frames} but got {len(image)=}")
         elif (
             isinstance(self.controlnet, MultiControlNetModel)
             or is_compiled
             and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
         ):
-            if not isinstance(image, list):
-                raise TypeError("For multiple controlnets: `image` must be type `list`")
-
-            # When `image` is a nested list:
-            # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
-            elif any(isinstance(i, list) for i in image):
-                raise ValueError("A single batch of multiple conditionings are supported at the moment.")
-            elif len(image) != len(self.controlnet.nets):
-                raise ValueError(
-                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
-                )
-
-            for control_ in image:
-                for image_ in control_:
-                    self.check_image(image_, prompt, prompt_embeds)
+            if not isinstance(image, list) or not isinstance(image[0], list):
+                raise TypeError(f"For multiple controlnets: `image` must be type list of lists but got {type(image)=}")
+            if len(image[0]) != num_frames:
+                raise ValueError(f"Expected length of image sublist as {num_frames} but got {len(image[0])=}")
+            if any(len(img) != len(image[0]) for img in image):
+                raise ValueError("All conditioning frame batches for multicontrolnet must be same size")
         else:
             assert False

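Taken together, the rewritten `check_inputs` branch enforces a frame-count contract rather than validating each image individually. A rough standalone restatement of the new logic (a hypothetical helper, not code from the commit):

```python
def validate_conditioning_frames(image, num_frames: int, num_controlnets: int) -> None:
    """Hypothetical helper mirroring the new check_inputs validation."""
    if num_controlnets == 1:
        # Single controlnet: a flat list with one conditioning image per frame.
        if not isinstance(image, list):
            raise TypeError(f"For single controlnet, `image` must be of type `list` but got {type(image)}")
        if len(image) != num_frames:
            raise ValueError(f"Expected image to have length {num_frames} but got {len(image)=}")
    else:
        # Multiple controlnets: a list of per-controlnet frame lists,
        # all the same length and matching num_frames.
        if not isinstance(image, list) or not isinstance(image[0], list):
            raise TypeError(f"For multiple controlnets: `image` must be type list of lists but got {type(image)=}")
        if len(image[0]) != num_frames:
            raise ValueError(f"Expected length of image sublist as {num_frames} but got {len(image[0])=}")
        if any(len(img) != len(image[0]) for img in image):
            raise ValueError("All conditioning frame batches for multicontrolnet must be the same size")
```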
@@ -913,6 +907,7 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
             prompt=prompt,
             height=height,
             width=width,
+            num_frames=num_frames,
             callback_steps=callback_steps,
             negative_prompt=negative_prompt,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
@@ -1000,9 +995,7 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
                     do_classifier_free_guidance=self.do_classifier_free_guidance,
                     guess_mode=guess_mode,
                 )
-
                 cond_prepared_frames.append(prepared_frame)
-
             conditioning_frames = cond_prepared_frames
         else:
             assert False