mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add pipeline_stable_diffusion_3_inpaint.py for SD3 Inference (#8709)
* Add pipeline_stable_diffusion_3_inpaint --------- Co-authored-by: Xu Cao <xucao2@jrehg-work-01.cs.illinois.edu> Co-authored-by: IrohXu <irohcao@gmail.com> Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
@@ -288,6 +288,7 @@ else:
|
||||
"StableCascadePriorPipeline",
|
||||
"StableDiffusion3ControlNetPipeline",
|
||||
"StableDiffusion3Img2ImgPipeline",
|
||||
"StableDiffusion3InpaintPipeline",
|
||||
"StableDiffusion3Pipeline",
|
||||
"StableDiffusionAdapterPipeline",
|
||||
"StableDiffusionAttendAndExcitePipeline",
|
||||
@@ -690,6 +691,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableCascadePriorPipeline,
|
||||
StableDiffusion3ControlNetPipeline,
|
||||
StableDiffusion3Img2ImgPipeline,
|
||||
StableDiffusion3InpaintPipeline,
|
||||
StableDiffusion3Pipeline,
|
||||
StableDiffusionAdapterPipeline,
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
|
||||
@@ -243,7 +243,11 @@ else:
|
||||
"StableDiffusionLDM3DPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["stable_diffusion_3"] = ["StableDiffusion3Pipeline", "StableDiffusion3Img2ImgPipeline"]
|
||||
_import_structure["stable_diffusion_3"] = [
|
||||
"StableDiffusion3Pipeline",
|
||||
"StableDiffusion3Img2ImgPipeline",
|
||||
"StableDiffusion3InpaintPipeline",
|
||||
]
|
||||
_import_structure["stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"]
|
||||
_import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
|
||||
_import_structure["stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"]
|
||||
@@ -523,7 +527,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
)
|
||||
from .stable_diffusion_3 import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline
|
||||
from .stable_diffusion_3 import (
|
||||
StableDiffusion3Img2ImgPipeline,
|
||||
StableDiffusion3InpaintPipeline,
|
||||
StableDiffusion3Pipeline,
|
||||
)
|
||||
from .stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
|
||||
from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline
|
||||
from .stable_diffusion_gligen import StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline
|
||||
|
||||
@@ -25,6 +25,7 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_3"] = ["StableDiffusion3Pipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_3_img2img"] = ["StableDiffusion3Img2ImgPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_3_inpaint"] = ["StableDiffusion3InpaintPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -35,6 +36,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
|
||||
from .pipeline_stable_diffusion_3_img2img import StableDiffusion3Img2ImgPipeline
|
||||
from .pipeline_stable_diffusion_3_inpaint import StableDiffusion3InpaintPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -962,6 +962,21 @@ class StableDiffusion3Img2ImgPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusion3InpaintPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusion3Pipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -0,0 +1,199 @@
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
SD3Transformer2DModel,
|
||||
StableDiffusion3InpaintPipeline,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..pipeline_params import (
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
|
||||
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class StableDiffusion3InpaintPipelineFastTests(PipelineLatentTesterMixin, unittest.TestCase, PipelineTesterMixin):
|
||||
pipeline_class = StableDiffusion3InpaintPipeline
|
||||
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
|
||||
required_optional_params = PipelineTesterMixin.required_optional_params
|
||||
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
|
||||
image_params = frozenset(
|
||||
[]
|
||||
) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
|
||||
image_latents_params = frozenset([])
|
||||
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"mask", "masked_image_latents"})
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = SD3Transformer2DModel(
|
||||
sample_size=32,
|
||||
patch_size=1,
|
||||
in_channels=16,
|
||||
num_layers=1,
|
||||
attention_head_dim=8,
|
||||
num_attention_heads=4,
|
||||
joint_attention_dim=32,
|
||||
caption_projection_dim=32,
|
||||
pooled_projection_dim=64,
|
||||
out_channels=16,
|
||||
)
|
||||
clip_text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
hidden_act="gelu",
|
||||
projection_dim=32,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder = CLIPTextModelWithProjection(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
|
||||
|
||||
text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
block_out_channels=(4,),
|
||||
layers_per_block=1,
|
||||
latent_channels=16,
|
||||
norm_num_groups=1,
|
||||
use_quant_conv=False,
|
||||
use_post_quant_conv=False,
|
||||
shift_factor=0.0609,
|
||||
scaling_factor=1.5035,
|
||||
)
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
return {
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"text_encoder_2": text_encoder_2,
|
||||
"text_encoder_3": text_encoder_3,
|
||||
"tokenizer": tokenizer,
|
||||
"tokenizer_2": tokenizer_2,
|
||||
"tokenizer_3": tokenizer_3,
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||
mask_image = torch.ones((1, 1, 32, 32)).to(device)
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"image": image,
|
||||
"mask_image": mask_image,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"output_type": "np",
|
||||
"strength": 0.8,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_3_inpaint_different_prompts(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_same_prompt = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["prompt_2"] = "a different prompt"
|
||||
inputs["prompt_3"] = "another different prompt"
|
||||
output_different_prompts = pipe(**inputs).images[0]
|
||||
|
||||
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
|
||||
|
||||
# Outputs should be different here
|
||||
assert max_diff > 1e-2
|
||||
|
||||
def test_stable_diffusion_3_inpaint_different_negative_prompts(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_same_prompt = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["negative_prompt_2"] = "deformed"
|
||||
inputs["negative_prompt_3"] = "blurry"
|
||||
output_different_prompts = pipe(**inputs).images[0]
|
||||
|
||||
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
|
||||
|
||||
# Outputs should be different here
|
||||
assert max_diff > 1e-2
|
||||
|
||||
def test_stable_diffusion_3_inpaint_prompt_embeds(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
output_with_prompt = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
prompt = inputs.pop("prompt")
|
||||
|
||||
do_classifier_free_guidance = inputs["guidance_scale"] > 1
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = pipe.encode_prompt(
|
||||
prompt,
|
||||
prompt_2=None,
|
||||
prompt_3=None,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
device=torch_device,
|
||||
)
|
||||
output_with_embeds = pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
**inputs,
|
||||
).images[0]
|
||||
|
||||
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
|
||||
assert max_diff < 1e-4
|
||||
|
||||
def test_multi_vae(self):
|
||||
pass
|
||||
Reference in New Issue
Block a user