mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Kolors] Add PAG (#8934)
* txt2img pag added * autopipe added, fixed case * style * apply suggestions * added fast tests, added todo tests * revert dummy objects for kolors * fix pag dummies * fix test imports * update pag tests * add kolor pag to docs --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -43,6 +43,11 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## KolorsPAGPipeline
|
||||
[[autodoc]] KolorsPAGPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionPAGPipeline
|
||||
[[autodoc]] StableDiffusionPAGPipeline
|
||||
- all
|
||||
|
||||
@@ -280,8 +280,6 @@ else:
|
||||
"KandinskyV22Pipeline",
|
||||
"KandinskyV22PriorEmb2EmbPipeline",
|
||||
"KandinskyV22PriorPipeline",
|
||||
"KolorsImg2ImgPipeline",
|
||||
"KolorsPipeline",
|
||||
"LatentConsistencyModelImg2ImgPipeline",
|
||||
"LatentConsistencyModelPipeline",
|
||||
"LattePipeline",
|
||||
@@ -397,7 +395,7 @@ except OptionalDependencyNotAvailable:
|
||||
]
|
||||
|
||||
else:
|
||||
_import_structure["pipelines"].extend(["KolorsImg2ImgPipeline", "KolorsPipeline"])
|
||||
_import_structure["pipelines"].extend(["KolorsImg2ImgPipeline", "KolorsPAGPipeline", "KolorsPipeline"])
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
||||
@@ -820,7 +818,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import KolorsImg2ImgPipeline, KolorsPipeline
|
||||
from .pipelines import KolorsImg2ImgPipeline, KolorsPAGPipeline, KolorsPipeline
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
@@ -146,6 +146,7 @@ else:
|
||||
_import_structure["pag"].extend(
|
||||
[
|
||||
"AnimateDiffPAGPipeline",
|
||||
"KolorsPAGPipeline",
|
||||
"HunyuanDiTPAGPipeline",
|
||||
"StableDiffusion3PAGPipeline",
|
||||
"StableDiffusionPAGPipeline",
|
||||
@@ -540,6 +541,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pag import (
|
||||
AnimateDiffPAGPipeline,
|
||||
HunyuanDiTPAGPipeline,
|
||||
KolorsPAGPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
StableDiffusion3PAGPipeline,
|
||||
StableDiffusionControlNetPAGPipeline,
|
||||
|
||||
@@ -162,8 +162,10 @@ _AUTO_INPAINT_DECODER_PIPELINES_MAPPING = OrderedDict(
|
||||
|
||||
if is_sentencepiece_available():
|
||||
from .kolors import KolorsPipeline
|
||||
from .pag import KolorsPAGPipeline
|
||||
|
||||
AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsPipeline
|
||||
AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kolors-pag"] = KolorsPAGPipeline
|
||||
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsPipeline
|
||||
|
||||
SUPPORTED_TASKS_MAPPINGS = [
|
||||
|
||||
@@ -143,10 +143,18 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
def unk_token(self) -> str:
|
||||
return "<unk>"
|
||||
|
||||
@unk_token.setter
|
||||
def unk_token(self, value: str):
|
||||
self._unk_token = value
|
||||
|
||||
@property
|
||||
def pad_token(self) -> str:
|
||||
return "<unk>"
|
||||
|
||||
@pad_token.setter
|
||||
def pad_token(self, value: str):
|
||||
self._pad_token = value
|
||||
|
||||
@property
|
||||
def pad_token_id(self):
|
||||
return self.get_command("<pad>")
|
||||
@@ -155,6 +163,10 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
def eos_token(self) -> str:
|
||||
return "</s>"
|
||||
|
||||
@eos_token.setter
|
||||
def eos_token(self, value: str):
|
||||
self._eos_token = value
|
||||
|
||||
@property
|
||||
def eos_token_id(self):
|
||||
return self.get_command("<eos>")
|
||||
|
||||
@@ -25,6 +25,7 @@ else:
|
||||
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
|
||||
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
|
||||
_import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
|
||||
_import_structure["pipeline_pag_kolors"] = ["KolorsPAGPipeline"]
|
||||
_import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
|
||||
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
|
||||
_import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"]
|
||||
@@ -44,6 +45,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
|
||||
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
|
||||
from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
|
||||
from .pipeline_pag_kolors import KolorsPAGPipeline
|
||||
from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
|
||||
from .pipeline_pag_sd import StableDiffusionPAGPipeline
|
||||
from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline
|
||||
|
||||
1136
src/diffusers/pipelines/pag/pipeline_pag_kolors.py
Normal file
1136
src/diffusers/pipelines/pag/pipeline_pag_kolors.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -17,6 +17,21 @@ class KolorsImg2ImgPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers", "sentencepiece"])
|
||||
|
||||
|
||||
class KolorsPAGPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers", "sentencepiece"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers", "sentencepiece"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "sentencepiece"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "sentencepiece"])
|
||||
|
||||
|
||||
class KolorsPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers", "sentencepiece"]
|
||||
|
||||
|
||||
@@ -133,23 +133,11 @@ class KolorsPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
||||
self.assertLessEqual(max_diff, 1e-3)
|
||||
|
||||
# throws AttributeError: property 'eos_token' of 'ChatGLMTokenizer' object has no setter
|
||||
# not sure if it is worth to fix it before integrating it to transformers
|
||||
def test_save_load_optional_components(self):
|
||||
# TODO (Alvaro) need to fix later
|
||||
pass
|
||||
super().test_save_load_optional_components(expected_max_difference=2e-4)
|
||||
|
||||
# throws AttributeError: property 'eos_token' of 'ChatGLMTokenizer' object has no setter
|
||||
# not sure if it is worth to fix it before integrating it to transformers
|
||||
def test_save_load_float16(self):
|
||||
# TODO (Alvaro) need to fix later
|
||||
pass
|
||||
|
||||
# throws AttributeError: property 'eos_token' of 'ChatGLMTokenizer' object has no setter
|
||||
# not sure if it is worth to fix it before integrating it to transformers
|
||||
def test_save_load_local(self):
|
||||
# TODO (Alvaro) need to fix later
|
||||
pass
|
||||
super().test_save_load_float16(expected_max_diff=2e-1)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=5e-4)
|
||||
self._test_inference_batch_single_identical(expected_max_diff=5e-4)
|
||||
|
||||
152
tests/pipelines/kolors/test_kolors_img2img.py
Normal file
152
tests/pipelines/kolors/test_kolors_img2img.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
EulerDiscreteScheduler,
|
||||
KolorsImg2ImgPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
)
|
||||
|
||||
from ..pipeline_params import (
|
||||
TEXT_TO_IMAGE_BATCH_PARAMS,
|
||||
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_TO_IMAGE_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class KolorsPipelineImg2ImgFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = KolorsImg2ImgPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
|
||||
|
||||
# Copied from tests.pipelines.kolors.test_kolors.KolorsPipelineFastTests.get_dummy_components
|
||||
def get_dummy_components(self, time_cond_proj_dim=None):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(2, 4),
|
||||
layers_per_block=2,
|
||||
time_cond_proj_dim=time_cond_proj_dim,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
# specific config below
|
||||
attention_head_dim=(2, 4),
|
||||
use_linear_projection=True,
|
||||
addition_embed_type="text_time",
|
||||
addition_time_embed_dim=8,
|
||||
transformer_layers_per_block=(1, 2),
|
||||
projection_class_embeddings_input_dim=56,
|
||||
cross_attention_dim=8,
|
||||
norm_num_groups=1,
|
||||
)
|
||||
scheduler = EulerDiscreteScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
steps_offset=1,
|
||||
beta_schedule="scaled_linear",
|
||||
timestep_spacing="leading",
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
sample_size=128,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
|
||||
tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"image_encoder": None,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
|
||||
image = image / 2 + 0.5
|
||||
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"image": image,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"output_type": "np",
|
||||
"strength": 0.8,
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
self.assertEqual(image.shape, (1, 64, 64, 3))
|
||||
expected_slice = np.array(
|
||||
[0.54823864, 0.43654007, 0.4886489, 0.63072854, 0.53641886, 0.4896852, 0.62123513, 0.5621531, 0.42809626]
|
||||
)
|
||||
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
||||
self.assertLessEqual(max_diff, 1e-3)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=3e-3)
|
||||
|
||||
def test_float16_inference(self):
|
||||
super().test_float16_inference(expected_max_diff=7e-2)
|
||||
252
tests/pipelines/pag/test_pag_kolors.py
Normal file
252
tests/pipelines/pag/test_pag_kolors.py
Normal file
@@ -0,0 +1,252 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
EulerDiscreteScheduler,
|
||||
KolorsPAGPipeline,
|
||||
KolorsPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer
|
||||
from diffusers.utils.testing_utils import enable_full_determinism
|
||||
|
||||
from ..pipeline_params import (
|
||||
TEXT_TO_IMAGE_BATCH_PARAMS,
|
||||
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_TO_IMAGE_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import (
|
||||
PipelineFromPipeTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class KolorsPAGPipelineFastTests(
|
||||
PipelineTesterMixin,
|
||||
PipelineFromPipeTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = KolorsPAGPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS.union({"pag_scale", "pag_adaptive_scale"})
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
|
||||
|
||||
# Copied from tests.pipelines.kolors.test_kolors.KolorsPipelineFastTests.get_dummy_components
|
||||
def get_dummy_components(self, time_cond_proj_dim=None):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(2, 4),
|
||||
layers_per_block=2,
|
||||
time_cond_proj_dim=time_cond_proj_dim,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
# specific config below
|
||||
attention_head_dim=(2, 4),
|
||||
use_linear_projection=True,
|
||||
addition_embed_type="text_time",
|
||||
addition_time_embed_dim=8,
|
||||
transformer_layers_per_block=(1, 2),
|
||||
projection_class_embeddings_input_dim=56,
|
||||
cross_attention_dim=8,
|
||||
norm_num_groups=1,
|
||||
)
|
||||
scheduler = EulerDiscreteScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
steps_offset=1,
|
||||
beta_schedule="scaled_linear",
|
||||
timestep_spacing="leading",
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
sample_size=128,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
|
||||
tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"image_encoder": None,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"pag_scale": 0.9,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_pag_disable_enable(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
|
||||
# base pipeline (expect same output when pag is disabled)
|
||||
pipe_sd = KolorsPipeline(**components)
|
||||
pipe_sd = pipe_sd.to(device)
|
||||
pipe_sd.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["pag_scale"]
|
||||
assert (
|
||||
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
# pag disabled with pag_scale=0.0
|
||||
pipe_pag = self.pipeline_class(**components)
|
||||
pipe_pag = pipe_pag.to(device)
|
||||
pipe_pag.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["pag_scale"] = 0.0
|
||||
out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
# pag enabled
|
||||
pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
|
||||
pipe_pag = pipe_pag.to(device)
|
||||
pipe_pag.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
|
||||
assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3
|
||||
|
||||
def test_pag_applied_layers(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
|
||||
# base pipeline
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# pag_applied_layers = ["mid","up","down"] should apply to all self-attention layers
|
||||
all_self_attn_layers = [k for k in pipe.unet.attn_processors.keys() if "attn1" in k]
|
||||
original_attn_procs = pipe.unet.attn_processors
|
||||
pag_layers = ["mid", "down", "up"]
|
||||
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||
assert set(pipe.pag_attn_processors) == set(all_self_attn_layers)
|
||||
|
||||
all_self_attn_mid_layers = [
|
||||
"mid_block.attentions.0.transformer_blocks.0.attn1.processor",
|
||||
"mid_block.attentions.0.transformer_blocks.1.attn1.processor",
|
||||
]
|
||||
pipe.unet.set_attn_processor(original_attn_procs.copy())
|
||||
pag_layers = ["mid"]
|
||||
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||
assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
|
||||
|
||||
pipe.unet.set_attn_processor(original_attn_procs.copy())
|
||||
pag_layers = ["mid_block"]
|
||||
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||
assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
|
||||
|
||||
pipe.unet.set_attn_processor(original_attn_procs.copy())
|
||||
pag_layers = ["mid_block.attentions.0"]
|
||||
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||
assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
|
||||
|
||||
# pag_applied_layers = ["mid.block_0.attentions_1"] does not exist in the model
|
||||
pipe.unet.set_attn_processor(original_attn_procs.copy())
|
||||
pag_layers = ["mid_block.attentions.1"]
|
||||
with self.assertRaises(ValueError):
|
||||
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||
|
||||
# pag_applied_layers = "down" should apply to all self-attention layers in down_blocks
|
||||
pipe.unet.set_attn_processor(original_attn_procs.copy())
|
||||
pag_layers = ["down"]
|
||||
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||
assert len(pipe.pag_attn_processors) == 4
|
||||
|
||||
pipe.unet.set_attn_processor(original_attn_procs.copy())
|
||||
pag_layers = ["down_blocks.0"]
|
||||
with self.assertRaises(ValueError):
|
||||
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||
|
||||
pipe.unet.set_attn_processor(original_attn_procs.copy())
|
||||
pag_layers = ["down_blocks.1"]
|
||||
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||
assert len(pipe.pag_attn_processors) == 4
|
||||
|
||||
pipe.unet.set_attn_processor(original_attn_procs.copy())
|
||||
pag_layers = ["down_blocks.1.attentions.1"]
|
||||
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||
assert len(pipe.pag_attn_processors) == 2
|
||||
|
||||
def test_pag_inference(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
|
||||
pipe_pag = pipe_pag.to(device)
|
||||
pipe_pag.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe_pag(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (
|
||||
1,
|
||||
64,
|
||||
64,
|
||||
3,
|
||||
), f"the shape of the output image should be (1, 64, 64, 3) but got {image.shape}"
|
||||
expected_slice = np.array(
|
||||
[0.26030684, 0.43192005, 0.4042826, 0.4189067, 0.5181305, 0.3832534, 0.472135, 0.4145031, 0.43726248]
|
||||
)
|
||||
|
||||
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
||||
self.assertLessEqual(max_diff, 1e-3)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
@@ -26,6 +26,7 @@ from diffusers import (
|
||||
ConsistencyDecoderVAE,
|
||||
DDIMScheduler,
|
||||
DiffusionPipeline,
|
||||
KolorsPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
UNet2DConditionModel,
|
||||
@@ -656,6 +657,8 @@ class PipelineFromPipeTesterMixin:
|
||||
def original_pipeline_class(self):
|
||||
if "xl" in self.pipeline_class.__name__.lower():
|
||||
original_pipeline_class = StableDiffusionXLPipeline
|
||||
elif "kolors" in self.pipeline_class.__name__.lower():
|
||||
original_pipeline_class = KolorsPipeline
|
||||
else:
|
||||
original_pipeline_class = StableDiffusionPipeline
|
||||
|
||||
@@ -681,6 +684,9 @@ class PipelineFromPipeTesterMixin:
|
||||
elif self.original_pipeline_class == StableDiffusionXLPipeline:
|
||||
original_repo = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
|
||||
original_kwargs = {"requires_aesthetics_score": True, "force_zeros_for_empty_prompt": False}
|
||||
elif self.original_pipeline_class == KolorsPipeline:
|
||||
original_repo = "hf-internal-testing/tiny-kolors-pipe"
|
||||
original_kwargs = {"force_zeros_for_empty_prompt": False}
|
||||
else:
|
||||
raise ValueError(
|
||||
"original_pipeline_class must be either StableDiffusionPipeline or StableDiffusionXLPipeline"
|
||||
|
||||
Reference in New Issue
Block a user