Mirror of https://github.com/huggingface/diffusers.git

[AnimateDiff+Controlnet] Fix multicontrolnet support (#6551)

* fix multicontrolnet support

* update README with multicontrolnet example
Author: Aryan V S
Date: 2024-01-15 20:06:54 +05:30 (committed by GitHub)
Parent: cb4b3f0b78
Commit: 119d734f6e
2 changed files with 94 additions and 28 deletions


@@ -2989,7 +2989,7 @@ pipe = DiffusionPipeline.from_pretrained(
custom_pipeline="pipeline_animatediff_controlnet",
).to(device="cuda", dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
)
pipe.enable_vae_slicing()
@@ -3005,7 +3005,7 @@ result = pipe(
width=512,
height=768,
conditioning_frames=conditioning_frames,
num_inference_steps=12,
num_inference_steps=20,
).frames[0]
from diffusers.utils import export_to_gif
@@ -3029,6 +3029,79 @@ export_to_gif(result.frames[0], "result.gif")
</tr>
</table>
You can also use multiple controlnets at once!
```python
import imageio
import requests
import torch
from io import BytesIO

from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter
from diffusers.pipelines import DiffusionPipeline
from diffusers.schedulers import DPMSolverMultistepScheduler
from PIL import Image

motion_id = "guoyww/animatediff-motion-adapter-v1-5-2"
adapter = MotionAdapter.from_pretrained(motion_id)
controlnet1 = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16)
controlnet2 = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)

model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
pipe = DiffusionPipeline.from_pretrained(
    model_id,
    motion_adapter=adapter,
    controlnet=[controlnet1, controlnet2],
    vae=vae,
    custom_pipeline="pipeline_animatediff_controlnet",
).to(device="cuda", dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
    model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
)
pipe.enable_vae_slicing()

def load_video(file_path: str):
    images = []

    if file_path.startswith(('http://', 'https://')):
        # If the file_path is a URL
        response = requests.get(file_path)
        response.raise_for_status()
        content = BytesIO(response.content)
        vid = imageio.get_reader(content)
    else:
        # Assuming it's a local file path
        vid = imageio.get_reader(file_path)

    for frame in vid:
        pil_image = Image.fromarray(frame)
        images.append(pil_image)

    return images

video = load_video("dance.gif")

# You need to install controlnet_aux first: `pip install controlnet_aux`
from controlnet_aux.processor import Processor

p1 = Processor("openpose_full")
cn1 = [p1(frame) for frame in video]

p2 = Processor("canny")
cn2 = [p2(frame) for frame in video]

prompt = "astronaut in space, dancing"
negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"

result = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=512,
    height=768,
    conditioning_frames=[cn1, cn2],
    num_inference_steps=20,
)

from diffusers.utils import export_to_gif
export_to_gif(result.frames[0], "result.gif")
```
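If you want the two controlnets to contribute with different strengths, the custom pipeline follows the standard Stable Diffusion ControlNet pipelines and should accept one `controlnet_conditioning_scale` value per controlnet. Treat this as an assumption to verify against the pipeline's `__call__` signature; the scale values below are purely illustrative.

```python
# Assumed to mirror StableDiffusionControlNetPipeline: one conditioning scale per controlnet.
# Verify that `controlnet_conditioning_scale` is exposed by pipeline_animatediff_controlnet.
result = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=512,
    height=768,
    conditioning_frames=[cn1, cn2],
    controlnet_conditioning_scale=[1.0, 0.8],  # openpose, canny
    num_inference_steps=20,
)
export_to_gif(result.frames[0], "result_weighted.gif")
```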
### DemoFusion
This pipeline is the official implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973).


@@ -14,7 +14,7 @@
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -66,7 +66,7 @@ EXAMPLE_DOC_STRING = """
... custom_pipeline="pipeline_animatediff_controlnet",
... ).to(device="cuda", dtype=torch.float16)
>>> pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
... model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
... model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
... )
>>> pipe.enable_vae_slicing()
@@ -83,7 +83,7 @@ EXAMPLE_DOC_STRING = """
... height=768,
... conditioning_frames=conditioning_frames,
... num_inference_steps=12,
... ).frames[0]
... )
>>> from diffusers.utils import export_to_gif
>>> export_to_gif(result.frames[0], "result.gif")
@@ -151,7 +151,7 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
motion_adapter: MotionAdapter,
controlnet: Union[ControlNetModel, MultiControlNetModel],
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -166,6 +166,9 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
super().__init__()
unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
if isinstance(controlnet, (list, tuple)):
controlnet = MultiControlNetModel(controlnet)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
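With this change, the constructor normalizes a plain list or tuple of controlnets into a `MultiControlNetModel`, so callers no longer have to build the wrapper themselves. A minimal sketch of that equivalence; the import path below matches what the pipeline file uses at the time of this commit and may have moved in later releases:

```python
import torch
from diffusers import ControlNetModel
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel

# Two controlnets, as in the README example above.
controlnet1 = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16)
controlnet2 = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)

# Before this fix, the caller had to wrap the list explicitly:
controlnet = MultiControlNetModel([controlnet1, controlnet2])

# After this fix, passing the raw list (or tuple) straight to the pipeline
# constructor is equivalent, because __init__ performs the same wrapping.
```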
@@ -488,6 +491,7 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
prompt,
height,
width,
num_frames,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
@@ -557,31 +561,21 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
or is_compiled
and isinstance(self.controlnet._orig_mod, ControlNetModel)
):
if isinstance(image, list):
for image_ in image:
self.check_image(image_, prompt, prompt_embeds)
else:
self.check_image(image, prompt, prompt_embeds)
if not isinstance(image, list):
raise TypeError(f"For single controlnet, `image` must be of type `list` but got {type(image)}")
if len(image) != num_frames:
raise ValueError(f"Excepted image to have length {num_frames} but got {len(image)=}")
elif (
isinstance(self.controlnet, MultiControlNetModel)
or is_compiled
and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
):
if not isinstance(image, list):
raise TypeError("For multiple controlnets: `image` must be type `list`")
# When `image` is a nested list:
# (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
elif any(isinstance(i, list) for i in image):
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
elif len(image) != len(self.controlnet.nets):
raise ValueError(
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
)
for control_ in image:
for image_ in control_:
self.check_image(image_, prompt, prompt_embeds)
if not isinstance(image, list) or not isinstance(image[0], list):
raise TypeError(f"For multiple controlnets: `image` must be type list of lists but got {type(image)=}")
if len(image[0]) != num_frames:
raise ValueError(f"Expected length of image sublist as {num_frames} but got {len(image[0])=}")
if any(len(img) != len(image[0]) for img in image):
raise ValueError("All conditioning frame batches for multicontrolnet must be same size")
else:
assert False
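The reworked checks pin down the expected `conditioning_frames` layout: a flat list with one image per frame for a single controlnet, and a list of per-controlnet lists, each of length `num_frames` and all the same size, for a `MultiControlNetModel`. A small sketch of the two shapes, using blank placeholder images in place of real conditioning frames:

```python
from PIL import Image

num_frames = 16
placeholder = Image.new("RGB", (512, 768))  # stand-in for a real conditioning image

# Single controlnet: one conditioning image per video frame.
single_controlnet_frames = [placeholder] * num_frames  # len == num_frames

# Multiple controlnets: one sublist per controlnet, each of length num_frames.
multi_controlnet_frames = [
    [placeholder] * num_frames,  # e.g. openpose conditioning frames
    [placeholder] * num_frames,  # e.g. canny conditioning frames
]
```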
@@ -913,6 +907,7 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
prompt=prompt,
height=height,
width=width,
num_frames=num_frames,
callback_steps=callback_steps,
negative_prompt=negative_prompt,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
@@ -1000,9 +995,7 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
cond_prepared_frames.append(prepared_frame)
conditioning_frames = cond_prepared_frames
else:
assert False