Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[Pipeline] Wuerstchen v3 aka Stable Cascade pipeline (#6487)
* initial diffNext v3 * move to v3 folder * imports * dry up the unets * no switch_level * fix init * add switch_level tp config * Fixed some things * Added pooled text embeddings * Initial work on adding image encoder * changes from @dome272 * Stuff for the image encoder processing and variable naming in decoder * fix arg name * inference fixes * inference fixes * default TimestepBlock without conds * c_skip=0 by default * fix bfloat16 to cpu * use config * undo temp change * fix gen_c_embeddings args * change text encoding * text encoding * undo print * undo .gitignore change * Allow WuerstchenV3PriorPipeline to use the base DDPM & DDIM schedulers * use WuerstchenV3Unet in both pipelines * fix imports * initial failing tests * cleanup * use scheduler.timesterps * some fixes to the tests, still not fully working * fix tests * fix prior tests * add dropout to the model_kwargs * more tests passing * update expected_slice * initial rename * rename tests * rename class names * make fix-copies * initial docs * autodocs * typos * fix arg docs * add text_encoder info * combined pipeline has optional image arg * fix documentation * Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * use self.config * Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * c_in -> in_channels * removed kwargs from unet's forward * Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * remove older callback api * removed kwargs and fixed decoder guidance > 1 * decoder takes emeds * check and use image_embeds * fixed all but one decoder test * fix decoder tests * update callback api * fix some more combined tests * push combined pipeline * initial docs * fix doc_string * update combined api * no test_callback_inputs test for combined pipeline * add optional components * fix ordering of components * fix combined tests * update convert script * Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * fix imports * move effnet out of deniosing loop * prompt_embeds_pooled only when doing guidance * Fix repeat shape * move StableCascadeUnet to models/unets/ * more descriptive names * converted when numpy() * StableCascadePriorPipelineOutput docs * rename StableCascadeUNet * add slow tests * fix slow 
tests * update * update * updated model_path * add args for weights * set push_to_hub to false * update * update * update * update * update * update * update * update * update * update * update * update * update * update --------- Co-authored-by: Dominic Rampas <d6582533@gmail.com> Co-authored-by: Pablo Pernias <pablo@pernias.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: YiYi Xu <yixu310@gmail.com> Co-authored-by: 99991 <99991@users.noreply.github.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
docs/source/en/_toctree.yml
@@ -318,6 +318,8 @@
    title: Semantic Guidance
  - local: api/pipelines/shap_e
    title: Shap-E
  - local: api/pipelines/stable_cascade
    title: Stable Cascade
  - sections:
    - local: api/pipelines/stable_diffusion/overview
      title: Overview
88  docs/source/en/api/pipelines/stable_cascade.md  Normal file
@@ -0,0 +1,88 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Stable Cascade

This model is built upon the [Würstchen](https://openreview.net/forum?id=gU58d5QeGv) architecture. Its main
difference from other models like Stable Diffusion is that it works in a much smaller latent space. Why is this
important? The smaller the latent space, the **faster** inference runs and the **cheaper** training becomes.
How small is the latent space? Stable Diffusion uses a compression factor of 8, so a 1024x1024 image is
encoded to 128x128. Stable Cascade achieves a compression factor of 42, meaning it can encode a
1024x1024 image to 24x24 while maintaining crisp reconstructions. The text-conditional model is then trained in this
highly compressed latent space. Previous versions of this architecture achieved a 16x cost reduction over Stable
Diffusion 1.5.

This kind of model is therefore well suited for use cases where efficiency matters. Furthermore, all known extensions
such as finetuning, LoRA, ControlNet, IP-Adapter, LCM, etc. are possible with this method as well.

The original codebase can be found at [Stability-AI/StableCascade](https://github.com/Stability-AI/StableCascade).

## Model Overview
Stable Cascade consists of three models: Stage A, Stage B and Stage C, representing a cascade for generating images,
hence the name "Stable Cascade".

Stage A and Stage B are used to compress images, similar to the role the VAE plays in Stable Diffusion.
However, this setup achieves a much higher compression of images. While the Stable Diffusion models use a
spatial compression factor of 8, encoding an image with a resolution of 1024 x 1024 to 128 x 128, Stable Cascade
achieves a compression factor of 42. This encodes a 1024 x 1024 image to 24 x 24 while still being able to accurately
decode the image, which brings the great benefit of cheaper training and inference. Stage C, in turn, is responsible
for generating the small 24 x 24 latents given a text prompt.
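
The two stages are exposed as separate pipelines in this PR. Below is a minimal sketch of running them back to back; the checkpoint names and dtypes are assumptions carried over from the example docstring further down in this PR, not guaranteed release names:

```py
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

# Checkpoint ids below mirror the docstring example in pipeline_stable_cascade.py
# and should be treated as assumptions.
prior_pipe = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
).to("cuda")
decoder_pipe = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.float16
).to("cuda")

prompt = "an image of a shiba inu, donning a spacesuit and helmet"
# Stage C: generate the small 24x24 image embeddings from the prompt.
prior_output = prior_pipe(prompt=prompt)
# Stages B and A: decode the embeddings back into a full-resolution image
# (cast to the decoder's dtype to keep the two stages compatible).
images = decoder_pipe(
    image_embeddings=prior_output.image_embeddings.to(torch.float16), prompt=prompt
).images
```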
## Uses

### Direct Use

The model is intended for research purposes for now. Possible research areas and tasks include

- Research on generative models.
- Safe deployment of models which have the potential to generate harmful content.
- Probing and understanding the limitations and biases of generative models.
- Generation of artworks and use in design and other artistic processes.
- Applications in educational or creative tools.

Excluded uses are described below.

### Out-of-Scope Use

The model was not trained to produce factual or true representations of people or events,
so using it to generate such content is out of scope for its abilities.
The model should not be used in any way that violates Stability AI's [Acceptable Use Policy](https://stability.ai/use-policy).

## Limitations and Bias

### Limitations
- Faces and people in general may not be generated properly.
- The autoencoding part of the model is lossy.

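## Combined Pipeline Usage

[`StableCascadeCombinedPipeline`] wraps the prior and decoder stages behind a single call. A minimal sketch, assuming a combined checkpoint saved under the `stabilityai/stable-cascade` name used elsewhere in this PR (the exact repository id is an assumption):

```py
import torch
from diffusers import StableCascadeCombinedPipeline

# The checkpoint id is an assumption; the conversion script in this PR saves the
# combined pipeline under "<save_org>/StableCascade".
pipe = StableCascadeCombinedPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(prompt="an image of a shiba inu, donning a spacesuit and helmet").images[0]
```
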
## StableCascadeCombinedPipeline

[[autodoc]] StableCascadeCombinedPipeline
- all
- __call__

## StableCascadePriorPipeline

[[autodoc]] StableCascadePriorPipeline
- all
- __call__

## StableCascadePriorPipelineOutput

[[autodoc]] pipelines.stable_cascade.pipeline_stable_cascade_prior.StableCascadePriorPipelineOutput

## StableCascadeDecoderPipeline

[[autodoc]] StableCascadeDecoderPipeline
- all
- __call__
215  scripts/convert_stable_cascade.py  Normal file
@@ -0,0 +1,215 @@
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
import argparse

import accelerate
import torch
from safetensors.torch import load_file
from transformers import (
    AutoTokenizer,
    CLIPConfig,
    CLIPImageProcessor,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)

from diffusers import (
    DDPMWuerstchenScheduler,
    StableCascadeCombinedPipeline,
    StableCascadeDecoderPipeline,
    StableCascadePriorPipeline,
)
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel


parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
parser.add_argument("--model_path", type=str, default="../StableCascade", help="Location of Stable Cascade weights")
parser.add_argument("--stage_c_name", type=str, default="stage_c.safetensors", help="Name of stage c checkpoint file")
parser.add_argument("--stage_b_name", type=str, default="stage_b.safetensors", help="Name of stage b checkpoint file")
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
parser.add_argument("--save_org", type=str, default="diffusers", help="Hub organization to save the pipelines to")
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")

args = parser.parse_args()
model_path = args.model_path

device = "cpu"

# set paths to model weights
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
decoder_checkpoint_path = f"{model_path}/{args.stage_b_name}"

# Clip Text encoder and tokenizer
config = CLIPConfig.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
config.text_config.projection_dim = config.projection_dim
text_encoder = CLIPTextModelWithProjection.from_pretrained(
    "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", config=config.text_config
)
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

# image processor
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

# Prior
if args.use_safetensors:
    orig_state_dict = load_file(prior_checkpoint_path, device=device)
else:
    orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)

state_dict = {}
for key in orig_state_dict.keys():
    if key.endswith("in_proj_weight"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
        state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
        state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
    elif key.endswith("in_proj_bias"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
        state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
        state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
    elif key.endswith("out_proj.weight"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
    elif key.endswith("out_proj.bias"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
    else:
        state_dict[key] = orig_state_dict[key]


with accelerate.init_empty_weights():
    prior_model = StableCascadeUNet(
        in_channels=16,
        out_channels=16,
        timestep_ratio_embedding_dim=64,
        patch_size=1,
        conditioning_dim=2048,
        block_out_channels=[2048, 2048],
        num_attention_heads=[32, 32],
        down_num_layers_per_block=[8, 24],
        up_num_layers_per_block=[24, 8],
        down_blocks_repeat_mappers=[1, 1],
        up_blocks_repeat_mappers=[1, 1],
        block_types_per_layer=[
            ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
        ],
        clip_text_in_channels=1280,
        clip_text_pooled_in_channels=1280,
        clip_image_in_channels=768,
        clip_seq=4,
        kernel_size=3,
        dropout=[0.1, 0.1],
        self_attn=True,
        timestep_conditioning_type=["sca", "crp"],
        switch_level=[False],
    )
load_model_dict_into_meta(prior_model, state_dict)

# scheduler for prior and decoder
scheduler = DDPMWuerstchenScheduler()

# Prior pipeline
prior_pipeline = StableCascadePriorPipeline(
    prior=prior_model,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    image_encoder=image_encoder,
    scheduler=scheduler,
    feature_extractor=feature_extractor,
)
prior_pipeline.save_pretrained(f"{args.save_org}/StableCascade-prior", push_to_hub=args.push_to_hub)

# Decoder
if args.use_safetensors:
    orig_state_dict = load_file(decoder_checkpoint_path, device=device)
else:
    orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)

state_dict = {}
for key in orig_state_dict.keys():
    if key.endswith("in_proj_weight"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
        state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
        state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
    elif key.endswith("in_proj_bias"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
        state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
        state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
    elif key.endswith("out_proj.weight"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
    elif key.endswith("out_proj.bias"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
    # rename clip_mapper to clip_txt_pooled_mapper
    elif key.endswith("clip_mapper.weight"):
        weights = orig_state_dict[key]
        state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
    elif key.endswith("clip_mapper.bias"):
        weights = orig_state_dict[key]
        state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
    else:
        state_dict[key] = orig_state_dict[key]

with accelerate.init_empty_weights():
    decoder = StableCascadeUNet(
        in_channels=4,
        out_channels=4,
        timestep_ratio_embedding_dim=64,
        patch_size=2,
        conditioning_dim=1280,
        block_out_channels=[320, 640, 1280, 1280],
        down_num_layers_per_block=[2, 6, 28, 6],
        up_num_layers_per_block=[6, 28, 6, 2],
        down_blocks_repeat_mappers=[1, 1, 1, 1],
        up_blocks_repeat_mappers=[3, 3, 2, 2],
        num_attention_heads=[0, 0, 20, 20],
        block_types_per_layer=[
            ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
            ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
            ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
        ],
        clip_text_pooled_in_channels=1280,
        clip_seq=4,
        effnet_in_channels=16,
        pixel_mapper_in_channels=3,
        kernel_size=3,
        dropout=[0, 0, 0.1, 0.1],
        self_attn=True,
        timestep_conditioning_type=["sca"],
    )
load_model_dict_into_meta(decoder, state_dict)

# VQGAN from Wuerstchen-V2
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")

# Decoder pipeline
decoder_pipeline = StableCascadeDecoderPipeline(
    decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
)
decoder_pipeline.save_pretrained(f"{args.save_org}/StableCascade-decoder", push_to_hub=args.push_to_hub)

# Stable Cascade combined pipeline
stable_cascade_pipeline = StableCascadeCombinedPipeline(
    # Decoder
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    decoder=decoder,
    scheduler=scheduler,
    vqgan=vqmodel,
    # Prior
    prior_text_encoder=text_encoder,
    prior_tokenizer=tokenizer,
    prior_prior=prior_model,
    prior_scheduler=scheduler,
    prior_image_encoder=image_encoder,
    prior_feature_extractor=feature_extractor,
)
stable_cascade_pipeline.save_pretrained(f"{args.save_org}/StableCascade", push_to_hub=args.push_to_hub)
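
To make the conversion above concrete, here is a sketch of how the script might be invoked and how the locally saved pipelines could be loaded afterwards. The command line and the local paths simply mirror the argparse defaults (`--save_org diffusers`), so treat them as illustrative rather than canonical:

```python
# Assumed invocation, mirroring the argparse defaults above:
#   python scripts/convert_stable_cascade.py --model_path ../StableCascade --use_safetensors
import torch

from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

# save_pretrained() above writes to "<save_org>/StableCascade-prior" and
# "<save_org>/StableCascade-decoder"; with the default save_org these are local folders.
prior = StableCascadePriorPipeline.from_pretrained("diffusers/StableCascade-prior", torch_dtype=torch.bfloat16)
decoder = StableCascadeDecoderPipeline.from_pretrained("diffusers/StableCascade-decoder", torch_dtype=torch.bfloat16)
```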
src/diffusers/__init__.py
@@ -86,6 +86,7 @@ else:
        "MotionAdapter",
        "MultiAdapter",
        "PriorTransformer",
        "StableCascadeUNet",
        "T2IAdapter",
        "T5FilmDecoder",
        "Transformer2DModel",
@@ -259,6 +260,9 @@ else:
        "SemanticStableDiffusionPipeline",
        "ShapEImg2ImgPipeline",
        "ShapEPipeline",
        "StableCascadeCombinedPipeline",
        "StableCascadeDecoderPipeline",
        "StableCascadePriorPipeline",
        "StableDiffusionAdapterPipeline",
        "StableDiffusionAttendAndExcitePipeline",
        "StableDiffusionControlNetImg2ImgPipeline",
@@ -626,6 +630,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        SemanticStableDiffusionPipeline,
        ShapEImg2ImgPipeline,
        ShapEPipeline,
        StableCascadeCombinedPipeline,
        StableCascadeDecoderPipeline,
        StableCascadePriorPipeline,
        StableDiffusionAdapterPipeline,
        StableDiffusionAttendAndExcitePipeline,
        StableDiffusionControlNetImg2ImgPipeline,

src/diffusers/models/__init__.py
@@ -47,6 +47,7 @@ if is_torch_available():
    _import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"]
    _import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
    _import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
    _import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"]
    _import_structure["unets.uvit_2d"] = ["UVit2DModel"]
    _import_structure["vq_model"] = ["VQModel"]
@@ -80,6 +81,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        I2VGenXLUNet,
        Kandinsky3UNet,
        MotionAdapter,
        StableCascadeUNet,
        UNet1DModel,
        UNet2DConditionModel,
        UNet2DModel,

src/diffusers/models/unets/__init__.py
@@ -10,6 +10,7 @@ if is_torch_available():
    from .unet_kandinsky3 import Kandinsky3UNet
    from .unet_motion_model import MotionAdapter, UNetMotionModel
    from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
    from .unet_stable_cascade import StableCascadeUNet
    from .uvit_2d import UVit2DModel

609  src/diffusers/models/unets/unet_stable_cascade.py  Normal file
@@ -0,0 +1,609 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ..attention_processor import Attention
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.wuerstchen.modeling_wuerstchen_common.WuerstchenLayerNorm with WuerstchenLayerNorm -> SDCascadeLayerNorm
|
||||
class SDCascadeLayerNorm(nn.LayerNorm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
x = super().forward(x)
|
||||
return x.permute(0, 3, 1, 2)
|
||||
|
||||
|
||||
class SDCascadeTimestepBlock(nn.Module):
|
||||
def __init__(self, c, c_timestep, conds=[]):
|
||||
super().__init__()
|
||||
linear_cls = nn.Linear
|
||||
self.mapper = linear_cls(c_timestep, c * 2)
|
||||
self.conds = conds
|
||||
for cname in conds:
|
||||
setattr(self, f"mapper_{cname}", linear_cls(c_timestep, c * 2))
|
||||
|
||||
def forward(self, x, t):
|
||||
t = t.chunk(len(self.conds) + 1, dim=1)
|
||||
a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
|
||||
for i, c in enumerate(self.conds):
|
||||
ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
|
||||
a, b = a + ac, b + bc
|
||||
return x * (1 + a) + b
|
||||
|
||||
|
||||
class SDCascadeResBlock(nn.Module):
|
||||
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
|
||||
super().__init__()
|
||||
self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
|
||||
self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||
self.channelwise = nn.Sequential(
|
||||
nn.Linear(c + c_skip, c * 4),
|
||||
nn.GELU(),
|
||||
GlobalResponseNorm(c * 4),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(c * 4, c),
|
||||
)
|
||||
|
||||
def forward(self, x, x_skip=None):
|
||||
x_res = x
|
||||
x = self.norm(self.depthwise(x))
|
||||
if x_skip is not None:
|
||||
x = torch.cat([x, x_skip], dim=1)
|
||||
x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
return x + x_res
|
||||
|
||||
|
||||
# from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
|
||||
class GlobalResponseNorm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
||||
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
||||
|
||||
def forward(self, x):
|
||||
agg_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
|
||||
stand_div_norm = agg_norm / (agg_norm.mean(dim=-1, keepdim=True) + 1e-6)
|
||||
return self.gamma * (x * stand_div_norm) + self.beta + x
|
||||
|
||||
|
||||
class SDCascadeAttnBlock(nn.Module):
|
||||
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
|
||||
super().__init__()
|
||||
linear_cls = nn.Linear
|
||||
|
||||
self.self_attn = self_attn
|
||||
self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||
self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
|
||||
self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
|
||||
|
||||
def forward(self, x, kv):
|
||||
kv = self.kv_mapper(kv)
|
||||
norm_x = self.norm(x)
|
||||
if self.self_attn:
|
||||
batch_size, channel, _, _ = x.shape
|
||||
kv = torch.cat([norm_x.view(batch_size, channel, -1).transpose(1, 2), kv], dim=1)
|
||||
x = x + self.attention(norm_x, encoder_hidden_states=kv)
|
||||
return x
|
||||
|
||||
|
||||
class UpDownBlock2d(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, mode, enabled=True):
|
||||
super().__init__()
|
||||
if mode not in ["up", "down"]:
|
||||
raise ValueError(f"{mode} not supported")
|
||||
interpolation = (
|
||||
nn.Upsample(scale_factor=2 if mode == "up" else 0.5, mode="bilinear", align_corners=True)
|
||||
if enabled
|
||||
else nn.Identity()
|
||||
)
|
||||
mapping = nn.Conv2d(in_channels, out_channels, kernel_size=1)
|
||||
self.blocks = nn.ModuleList([interpolation, mapping] if mode == "up" else [mapping, interpolation])
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class StableCascadeUNetOutput(BaseOutput):
|
||||
sample: torch.FloatTensor = None
|
||||
|
||||
|
||||
class StableCascadeUNet(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 16,
|
||||
timestep_ratio_embedding_dim: int = 64,
|
||||
patch_size: int = 1,
|
||||
conditioning_dim: int = 2048,
|
||||
block_out_channels: Tuple[int] = (2048, 2048),
|
||||
num_attention_heads: Tuple[int] = (32, 32),
|
||||
down_num_layers_per_block: Tuple[int] = (8, 24),
|
||||
up_num_layers_per_block: Tuple[int] = (24, 8),
|
||||
down_blocks_repeat_mappers: Optional[Tuple[int]] = (
|
||||
1,
|
||||
1,
|
||||
),
|
||||
up_blocks_repeat_mappers: Optional[Tuple[int]] = (1, 1),
|
||||
block_types_per_layer: Tuple[Tuple[str]] = (
|
||||
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
|
||||
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
|
||||
),
|
||||
clip_text_in_channels: Optional[int] = None,
|
||||
clip_text_pooled_in_channels=1280,
|
||||
clip_image_in_channels: Optional[int] = None,
|
||||
clip_seq=4,
|
||||
effnet_in_channels: Optional[int] = None,
|
||||
pixel_mapper_in_channels: Optional[int] = None,
|
||||
kernel_size=3,
|
||||
dropout: Union[float, Tuple[float]] = (0.1, 0.1),
|
||||
self_attn: Union[bool, Tuple[bool]] = True,
|
||||
timestep_conditioning_type: Tuple[str] = ("sca", "crp"),
|
||||
switch_level: Optional[Tuple[bool]] = None,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters:
|
||||
in_channels (`int`, defaults to 16):
|
||||
Number of channels in the input sample.
|
||||
out_channels (`int`, defaults to 16):
|
||||
Number of channels in the output sample.
|
||||
timestep_ratio_embedding_dim (`int`, defaults to 64):
|
||||
Dimension of the projected time embedding.
|
||||
patch_size (`int`, defaults to 1):
|
||||
Patch size to use for pixel unshuffling layer
|
||||
conditioning_dim (`int`, defaults to 2048):
|
||||
Dimension of the image and text conditional embedding.
|
||||
block_out_channels (Tuple[int], defaults to (2048, 2048)):
|
||||
Tuple of output channels for each block.
|
||||
num_attention_heads (Tuple[int], defaults to (32, 32)):
|
||||
Number of attention heads in each attention block. Set to -1 if the block types in a layer do not use attention.
|
||||
down_num_layers_per_block (Tuple[int], defaults to [8, 24]):
|
||||
Number of layers in each down block.
|
||||
up_num_layers_per_block (Tuple[int], defaults to [24, 8]):
|
||||
Number of layers in each up block.
|
||||
down_blocks_repeat_mappers (Tuple[int], optional, defaults to [1, 1]):
|
||||
Number of 1x1 Convolutional layers to repeat in each down block.
|
||||
up_blocks_repeat_mappers (Tuple[int], optional, defaults to [1, 1]):
|
||||
Number of 1x1 Convolutional layers to repeat in each up block.
|
||||
block_types_per_layer (Tuple[Tuple[str]], optional,
|
||||
defaults to (
|
||||
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
|
||||
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock")
|
||||
):
|
||||
Block types used in each layer of the up/down blocks.
|
||||
clip_text_in_channels (`int`, *optional*, defaults to `None`):
|
||||
Number of input channels for CLIP based text conditioning.
|
||||
clip_text_pooled_in_channels (`int`, *optional*, defaults to 1280):
|
||||
Number of input channels for pooled CLIP text embeddings.
|
||||
clip_image_in_channels (`int`, *optional*):
|
||||
Number of input channels for CLIP based image conditioning.
|
||||
clip_seq (`int`, *optional*, defaults to 4):
|
||||
effnet_in_channels (`int`, *optional*, defaults to `None`):
|
||||
Number of input channels for effnet conditioning.
|
||||
pixel_mapper_in_channels (`int`, defaults to `None`):
|
||||
Number of input channels for pixel mapper conditioning.
|
||||
kernel_size (`int`, *optional*, defaults to 3):
|
||||
Kernel size to use in the block convolutional layers.
|
||||
dropout (Tuple[float], *optional*, defaults to (0.1, 0.1)):
|
||||
Dropout to use per block.
|
||||
self_attn (Union[bool, Tuple[bool]]):
|
||||
Tuple of booleans that determine whether to use self attention in a block or not.
|
||||
timestep_conditioning_type (Tuple[str], defaults to ("sca", "crp")):
|
||||
Timestep conditioning type.
|
||||
switch_level (Optional[Tuple[bool]], *optional*, defaults to `None`):
|
||||
Tuple that indicates whether upsampling or downsampling should be applied in a block
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
if len(block_out_channels) != len(down_num_layers_per_block):
|
||||
raise ValueError(
|
||||
f"Number of elements in `down_num_layers_per_block` must match the length of `block_out_channels`: {len(block_out_channels)}"
|
||||
)
|
||||
|
||||
elif len(block_out_channels) != len(up_num_layers_per_block):
|
||||
raise ValueError(
|
||||
f"Number of elements in `up_num_layers_per_block` must match the length of `block_out_channels`: {len(block_out_channels)}"
|
||||
)
|
||||
|
||||
elif len(block_out_channels) != len(down_blocks_repeat_mappers):
|
||||
raise ValueError(
|
||||
f"Number of elements in `down_blocks_repeat_mappers` must match the length of `block_out_channels`: {len(block_out_channels)}"
|
||||
)
|
||||
|
||||
elif len(block_out_channels) != len(up_blocks_repeat_mappers):
|
||||
raise ValueError(
|
||||
f"Number of elements in `up_blocks_repeat_mappers` must match the length of `block_out_channels`: {len(block_out_channels)}"
|
||||
)
|
||||
|
||||
elif len(block_out_channels) != len(block_types_per_layer):
|
||||
raise ValueError(
|
||||
f"Number of elements in `block_types_per_layer` must match the length of `block_out_channels`: {len(block_out_channels)}"
|
||||
)
|
||||
|
||||
if isinstance(dropout, float):
|
||||
dropout = (dropout,) * len(block_out_channels)
|
||||
if isinstance(self_attn, bool):
|
||||
self_attn = (self_attn,) * len(block_out_channels)
|
||||
|
||||
# CONDITIONING
|
||||
if effnet_in_channels is not None:
|
||||
self.effnet_mapper = nn.Sequential(
|
||||
nn.Conv2d(effnet_in_channels, block_out_channels[0] * 4, kernel_size=1),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(block_out_channels[0] * 4, block_out_channels[0], kernel_size=1),
|
||||
SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6),
|
||||
)
|
||||
if pixel_mapper_in_channels is not None:
|
||||
self.pixels_mapper = nn.Sequential(
|
||||
nn.Conv2d(pixel_mapper_in_channels, block_out_channels[0] * 4, kernel_size=1),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(block_out_channels[0] * 4, block_out_channels[0], kernel_size=1),
|
||||
SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6),
|
||||
)
|
||||
|
||||
self.clip_txt_pooled_mapper = nn.Linear(clip_text_pooled_in_channels, conditioning_dim * clip_seq)
|
||||
if clip_text_in_channels is not None:
|
||||
self.clip_txt_mapper = nn.Linear(clip_text_in_channels, conditioning_dim)
|
||||
if clip_image_in_channels is not None:
|
||||
self.clip_img_mapper = nn.Linear(clip_image_in_channels, conditioning_dim * clip_seq)
|
||||
self.clip_norm = nn.LayerNorm(conditioning_dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.embedding = nn.Sequential(
|
||||
nn.PixelUnshuffle(patch_size),
|
||||
nn.Conv2d(in_channels * (patch_size**2), block_out_channels[0], kernel_size=1),
|
||||
SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6),
|
||||
)
|
||||
|
||||
def get_block(block_type, in_channels, nhead, c_skip=0, dropout=0, self_attn=True):
|
||||
if block_type == "SDCascadeResBlock":
|
||||
return SDCascadeResBlock(in_channels, c_skip, kernel_size=kernel_size, dropout=dropout)
|
||||
elif block_type == "SDCascadeAttnBlock":
|
||||
return SDCascadeAttnBlock(in_channels, conditioning_dim, nhead, self_attn=self_attn, dropout=dropout)
|
||||
elif block_type == "SDCascadeTimestepBlock":
|
||||
return SDCascadeTimestepBlock(
|
||||
in_channels, timestep_ratio_embedding_dim, conds=timestep_conditioning_type
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Block type {block_type} not supported")
|
||||
|
||||
# BLOCKS
|
||||
# -- down blocks
|
||||
self.down_blocks = nn.ModuleList()
|
||||
self.down_downscalers = nn.ModuleList()
|
||||
self.down_repeat_mappers = nn.ModuleList()
|
||||
for i in range(len(block_out_channels)):
|
||||
if i > 0:
|
||||
self.down_downscalers.append(
|
||||
nn.Sequential(
|
||||
SDCascadeLayerNorm(block_out_channels[i - 1], elementwise_affine=False, eps=1e-6),
|
||||
UpDownBlock2d(
|
||||
block_out_channels[i - 1], block_out_channels[i], mode="down", enabled=switch_level[i - 1]
|
||||
)
|
||||
if switch_level is not None
|
||||
else nn.Conv2d(block_out_channels[i - 1], block_out_channels[i], kernel_size=2, stride=2),
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.down_downscalers.append(nn.Identity())
|
||||
|
||||
down_block = nn.ModuleList()
|
||||
for _ in range(down_num_layers_per_block[i]):
|
||||
for block_type in block_types_per_layer[i]:
|
||||
block = get_block(
|
||||
block_type,
|
||||
block_out_channels[i],
|
||||
num_attention_heads[i],
|
||||
dropout=dropout[i],
|
||||
self_attn=self_attn[i],
|
||||
)
|
||||
down_block.append(block)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
if down_blocks_repeat_mappers is not None:
|
||||
block_repeat_mappers = nn.ModuleList()
|
||||
for _ in range(down_blocks_repeat_mappers[i] - 1):
|
||||
block_repeat_mappers.append(nn.Conv2d(block_out_channels[i], block_out_channels[i], kernel_size=1))
|
||||
self.down_repeat_mappers.append(block_repeat_mappers)
|
||||
|
||||
# -- up blocks
|
||||
self.up_blocks = nn.ModuleList()
|
||||
self.up_upscalers = nn.ModuleList()
|
||||
self.up_repeat_mappers = nn.ModuleList()
|
||||
for i in reversed(range(len(block_out_channels))):
|
||||
if i > 0:
|
||||
self.up_upscalers.append(
|
||||
nn.Sequential(
|
||||
SDCascadeLayerNorm(block_out_channels[i], elementwise_affine=False, eps=1e-6),
|
||||
UpDownBlock2d(
|
||||
block_out_channels[i], block_out_channels[i - 1], mode="up", enabled=switch_level[i - 1]
|
||||
)
|
||||
if switch_level is not None
|
||||
else nn.ConvTranspose2d(
|
||||
block_out_channels[i], block_out_channels[i - 1], kernel_size=2, stride=2
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.up_upscalers.append(nn.Identity())
|
||||
|
||||
up_block = nn.ModuleList()
|
||||
for j in range(up_num_layers_per_block[::-1][i]):
|
||||
for k, block_type in enumerate(block_types_per_layer[i]):
|
||||
c_skip = block_out_channels[i] if i < len(block_out_channels) - 1 and j == k == 0 else 0
|
||||
block = get_block(
|
||||
block_type,
|
||||
block_out_channels[i],
|
||||
num_attention_heads[i],
|
||||
c_skip=c_skip,
|
||||
dropout=dropout[i],
|
||||
self_attn=self_attn[i],
|
||||
)
|
||||
up_block.append(block)
|
||||
self.up_blocks.append(up_block)
|
||||
|
||||
if up_blocks_repeat_mappers is not None:
|
||||
block_repeat_mappers = nn.ModuleList()
|
||||
for _ in range(up_blocks_repeat_mappers[::-1][i] - 1):
|
||||
block_repeat_mappers.append(nn.Conv2d(block_out_channels[i], block_out_channels[i], kernel_size=1))
|
||||
self.up_repeat_mappers.append(block_repeat_mappers)
|
||||
|
||||
# OUTPUT
|
||||
self.clf = nn.Sequential(
|
||||
SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6),
|
||||
nn.Conv2d(block_out_channels[0], out_channels * (patch_size**2), kernel_size=1),
|
||||
nn.PixelShuffle(patch_size),
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, value=False):
|
||||
self.gradient_checkpointing = value
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||
torch.nn.init.xavier_uniform_(m.weight)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02)
|
||||
nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) if hasattr(self, "clip_txt_mapper") else None
|
||||
nn.init.normal_(self.clip_img_mapper.weight, std=0.02) if hasattr(self, "clip_img_mapper") else None
|
||||
|
||||
if hasattr(self, "effnet_mapper"):
|
||||
nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings
|
||||
nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings
|
||||
|
||||
if hasattr(self, "pixels_mapper"):
|
||||
nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings
|
||||
nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
|
||||
|
||||
torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
|
||||
nn.init.constant_(self.clf[1].weight, 0) # outputs
|
||||
|
||||
# blocks
|
||||
for level_block in self.down_blocks + self.up_blocks:
|
||||
for block in level_block:
|
||||
if isinstance(block, SDCascadeResBlock):
|
||||
block.channelwise[-1].weight.data *= np.sqrt(1 / sum(self.config.blocks[0]))
|
||||
elif isinstance(block, SDCascadeTimestepBlock):
|
||||
nn.init.constant_(block.mapper.weight, 0)
|
||||
|
||||
def get_timestep_ratio_embedding(self, timestep_ratio, max_positions=10000):
|
||||
r = timestep_ratio * max_positions
|
||||
half_dim = self.config.timestep_ratio_embedding_dim // 2
|
||||
|
||||
emb = math.log(max_positions) / (half_dim - 1)
|
||||
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
|
||||
emb = r[:, None] * emb[None, :]
|
||||
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
|
||||
|
||||
if self.config.timestep_ratio_embedding_dim % 2 == 1: # zero pad
|
||||
emb = nn.functional.pad(emb, (0, 1), mode="constant")
|
||||
|
||||
return emb.to(dtype=r.dtype)
|
||||
|
||||
def get_clip_embeddings(self, clip_txt_pooled, clip_txt=None, clip_img=None):
|
||||
if len(clip_txt_pooled.shape) == 2:
|
||||
clip_txt_pool = clip_txt_pooled.unsqueeze(1)
|
||||
clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(
|
||||
clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.config.clip_seq, -1
|
||||
)
|
||||
if clip_txt is not None and clip_img is not None:
|
||||
clip_txt = self.clip_txt_mapper(clip_txt)
|
||||
if len(clip_img.shape) == 2:
|
||||
clip_img = clip_img.unsqueeze(1)
|
||||
clip_img = self.clip_img_mapper(clip_img).view(
|
||||
clip_img.size(0), clip_img.size(1) * self.config.clip_seq, -1
|
||||
)
|
||||
clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
|
||||
else:
|
||||
clip = clip_txt_pool
|
||||
return self.clip_norm(clip)
|
||||
|
||||
def _down_encode(self, x, r_embed, clip):
|
||||
level_outputs = []
|
||||
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
x = downscaler(x)
|
||||
for i in range(len(repmap) + 1):
|
||||
for block in down_block:
|
||||
if isinstance(block, SDCascadeResBlock):
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
|
||||
elif isinstance(block, SDCascadeAttnBlock):
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, clip, use_reentrant=False
|
||||
)
|
||||
elif isinstance(block, SDCascadeTimestepBlock):
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, r_embed, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, use_reentrant=False
)
|
||||
if i < len(repmap):
|
||||
x = repmap[i](x)
|
||||
level_outputs.insert(0, x)
|
||||
else:
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
x = downscaler(x)
|
||||
for i in range(len(repmap) + 1):
|
||||
for block in down_block:
|
||||
if isinstance(block, SDCascadeResBlock):
|
||||
x = block(x)
|
||||
elif isinstance(block, SDCascadeAttnBlock):
|
||||
x = block(x, clip)
|
||||
elif isinstance(block, SDCascadeTimestepBlock):
|
||||
x = block(x, r_embed)
|
||||
else:
|
||||
x = block(x)
|
||||
if i < len(repmap):
|
||||
x = repmap[i](x)
|
||||
level_outputs.insert(0, x)
|
||||
return level_outputs
|
||||
|
||||
def _up_decode(self, level_outputs, r_embed, clip):
|
||||
x = level_outputs[0]
|
||||
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
for j in range(len(repmap) + 1):
|
||||
for k, block in enumerate(up_block):
|
||||
if isinstance(block, SDCascadeResBlock):
|
||||
skip = level_outputs[i] if k == 0 and i > 0 else None
|
||||
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
|
||||
x = torch.nn.functional.interpolate(
|
||||
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
|
||||
)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, skip, use_reentrant=False
|
||||
)
|
||||
elif isinstance(block, SDCascadeAttnBlock):
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, clip, use_reentrant=False
|
||||
)
|
||||
elif isinstance(block, SDCascadeTimestepBlock):
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, r_embed, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
|
||||
if j < len(repmap):
|
||||
x = repmap[j](x)
|
||||
x = upscaler(x)
|
||||
else:
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
for j in range(len(repmap) + 1):
|
||||
for k, block in enumerate(up_block):
|
||||
if isinstance(block, SDCascadeResBlock):
|
||||
skip = level_outputs[i] if k == 0 and i > 0 else None
|
||||
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
|
||||
x = torch.nn.functional.interpolate(
|
||||
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
|
||||
)
|
||||
x = block(x, skip)
|
||||
elif isinstance(block, SDCascadeAttnBlock):
|
||||
x = block(x, clip)
|
||||
elif isinstance(block, SDCascadeTimestepBlock):
|
||||
x = block(x, r_embed)
|
||||
else:
|
||||
x = block(x)
|
||||
if j < len(repmap):
|
||||
x = repmap[j](x)
|
||||
x = upscaler(x)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample,
|
||||
timestep_ratio,
|
||||
clip_text_pooled,
|
||||
clip_text=None,
|
||||
clip_img=None,
|
||||
effnet=None,
|
||||
pixels=None,
|
||||
sca=None,
|
||||
crp=None,
|
||||
return_dict=True,
|
||||
):
|
||||
if pixels is None:
|
||||
pixels = sample.new_zeros(sample.size(0), 3, 8, 8)
|
||||
|
||||
# Process the conditioning embeddings
|
||||
timestep_ratio_embed = self.get_timestep_ratio_embedding(timestep_ratio)
|
||||
for c in self.config.timestep_conditioning_type:
|
||||
if c == "sca":
|
||||
cond = sca
|
||||
elif c == "crp":
|
||||
cond = crp
|
||||
else:
|
||||
cond = None
|
||||
t_cond = cond or torch.zeros_like(timestep_ratio)
|
||||
timestep_ratio_embed = torch.cat([timestep_ratio_embed, self.get_timestep_ratio_embedding(t_cond)], dim=1)
|
||||
clip = self.get_clip_embeddings(clip_txt_pooled=clip_text_pooled, clip_txt=clip_text, clip_img=clip_img)
|
||||
|
||||
# Model Blocks
|
||||
x = self.embedding(sample)
|
||||
if hasattr(self, "effnet_mapper") and effnet is not None:
|
||||
x = x + self.effnet_mapper(
|
||||
nn.functional.interpolate(effnet, size=x.shape[-2:], mode="bilinear", align_corners=True)
|
||||
)
|
||||
if hasattr(self, "pixels_mapper"):
|
||||
x = x + nn.functional.interpolate(
|
||||
self.pixels_mapper(pixels), size=x.shape[-2:], mode="bilinear", align_corners=True
|
||||
)
|
||||
level_outputs = self._down_encode(x, timestep_ratio_embed, clip)
|
||||
x = self._up_decode(level_outputs, timestep_ratio_embed, clip)
|
||||
sample = self.clf(x)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
return StableCascadeUNetOutput(sample=sample)
|
||||
src/diffusers/pipelines/__init__.py
@@ -176,6 +176,11 @@ else:
    _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline"]
    _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
    _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
    _import_structure["stable_cascade"] = [
        "StableCascadeCombinedPipeline",
        "StableCascadeDecoderPipeline",
        "StableCascadePriorPipeline",
    ]
    _import_structure["stable_diffusion"].extend(
        [
            "CLIPImageProjection",
@@ -424,6 +429,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .pixart_alpha import PixArtAlphaPipeline
    from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
    from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
    from .stable_cascade import (
        StableCascadeCombinedPipeline,
        StableCascadeDecoderPipeline,
        StableCascadePriorPipeline,
    )
    from .stable_diffusion import (
        CLIPImageProjection,
        StableDiffusionDepth2ImgPipeline,

50  src/diffusers/pipelines/stable_cascade/__init__.py  Normal file
@@ -0,0 +1,50 @@
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_stable_cascade"] = ["StableCascadeDecoderPipeline"]
    _import_structure["pipeline_stable_cascade_combined"] = ["StableCascadeCombinedPipeline"]
    _import_structure["pipeline_stable_cascade_prior"] = ["StableCascadePriorPipeline"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .pipeline_stable_cascade import StableCascadeDecoderPipeline
        from .pipeline_stable_cascade_combined import StableCascadeCombinedPipeline
        from .pipeline_stable_cascade_prior import StableCascadePriorPipeline
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
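
With the lazy import structure above in place, the new pipelines resolve from the top-level `diffusers` namespace. A minimal sketch (no weights are loaded here, only the import path is exercised):

```python
# The classes are registered lazily; importing them from the top-level package
# only touches the stable_cascade submodule when the names are first accessed.
from diffusers import (
    StableCascadeCombinedPipeline,
    StableCascadeDecoderPipeline,
    StableCascadePriorPipeline,
)

print(StableCascadePriorPipeline, StableCascadeDecoderPipeline, StableCascadeCombinedPipeline)
```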
465  src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py  Normal file
@@ -0,0 +1,465 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...models import StableCascadeUNet
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
from ...utils import logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline
|
||||
|
||||
>>> prior_pipe = StableCascadePriorPipeline.from_pretrained(
|
||||
... "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
|
||||
... ).to("cuda")
|
||||
>>> gen_pipe = StableCascadeDecoderPipeline.from_pretrain(
|
||||
... "stabilityai/stable-cascade", torch_dtype=torch.float16
|
||||
... ).to("cuda")
|
||||
|
||||
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
|
||||
>>> prior_output = pipe(prompt)
|
||||
>>> images = gen_pipe(prior_output.image_embeddings, prompt=prompt)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class StableCascadeDecoderPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for generating images from the Stable Cascade model.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
The CLIP tokenizer.
|
||||
text_encoder (`CLIPTextModel`):
|
||||
The CLIP text encoder.
|
||||
decoder ([`StableCascadeUNet`]):
|
||||
The Stable Cascade decoder unet.
|
||||
vqgan ([`PaellaVQModel`]):
|
||||
The VQGAN model.
|
||||
scheduler ([`DDPMWuerstchenScheduler`]):
|
||||
A scheduler to be used in combination with `prior` to generate image embedding.
|
||||
latent_dim_scale (float, `optional`, defaults to 10.67):
|
||||
Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are
|
||||
height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and
|
||||
width=int(24*10.67)=256 in order to match the training conditions.
|
||||
"""
|
||||
|
||||
unet_name = "decoder"
|
||||
text_encoder_name = "text_encoder"
|
||||
model_cpu_offload_seq = "text_encoder->decoder->vqgan"
|
||||
_callback_tensor_inputs = [
|
||||
"latents",
|
||||
"prompt_embeds_pooled",
|
||||
"negative_prompt_embeds",
|
||||
"image_embeddings",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
decoder: StableCascadeUNet,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModel,
|
||||
scheduler: DDPMWuerstchenScheduler,
|
||||
vqgan: PaellaVQModel,
|
||||
latent_dim_scale: float = 10.67,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.register_modules(
|
||||
decoder=decoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
scheduler=scheduler,
|
||||
vqgan=vqgan,
|
||||
)
|
||||
self.register_to_config(latent_dim_scale=latent_dim_scale)
|
||||
|
||||
def prepare_latents(self, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler):
|
||||
batch_size, channels, height, width = image_embeddings.shape
|
||||
latents_shape = (
|
||||
batch_size * num_images_per_prompt,
|
||||
4,
|
||||
int(height * self.config.latent_dim_scale),
|
||||
int(width * self.config.latent_dim_scale),
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
device,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
prompt=None,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
if prompt_embeds is None:
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
attention_mask = text_inputs.attention_mask
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
attention_mask = attention_mask[:, : self.tokenizer.model_max_length]
|
||||
|
||||
text_encoder_output = self.text_encoder(
|
||||
text_input_ids.to(device), attention_mask=attention_mask.to(device), output_hidden_states=True
|
||||
)
|
||||
prompt_embeds = text_encoder_output.hidden_states[-1]
|
||||
if prompt_embeds_pooled is None:
|
||||
prompt_embeds_pooled = text_encoder_output.text_embeds.unsqueeze(1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=self.text_encoder.dtype, device=device)
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
if negative_prompt_embeds is None and do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
negative_prompt_embeds_text_encoder_output = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=uncond_input.attention_mask.to(device),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.hidden_states[-1]
|
||||
negative_prompt_embeds_pooled = negative_prompt_embeds_text_encoder_output.text_embeds.unsqueeze(1)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
seq_len = negative_prompt_embeds_pooled.shape[1]
|
||||
negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.to(
|
||||
dtype=self.text_encoder.dtype, device=device
|
||||
)
|
||||
negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
# done duplicates
|
||||
|
||||
return prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
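# With the decoder's default `guidance_scale` of 0.0, classifier-free guidance is disabled and only a single
# (conditional) forward pass is run per step; any value above 1 doubles the batch for the unconditional pass instead.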
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image_embeddings: Union[torch.FloatTensor, List[torch.FloatTensor]],
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_inference_steps: int = 10,
|
||||
guidance_scale: float = 0.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image_embeddings (`torch.FloatTensor` or `List[torch.FloatTensor]`):
Image embeddings either extracted from an image or generated by a prior model.
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 10):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
`guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely
linked to the text `prompt`, usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
|
||||
(`np.array`) or `"pt"` (`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that is called at the end of each denoising step during inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
# 0. Define commonly used variables
|
||||
device = self._execution_device
|
||||
dtype = self.decoder.dtype
|
||||
self._guidance_scale = guidance_scale
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
if isinstance(image_embeddings, list):
|
||||
image_embeddings = torch.cat(image_embeddings, dim=0)
|
||||
batch_size = image_embeddings.shape[0]
|
||||
|
||||
# 2. Encode caption
|
||||
if prompt_embeds is None and negative_prompt_embeds is None:
|
||||
prompt_embeds, _, negative_prompt_embeds, _ = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
prompt_embeds_pooled = (
|
||||
torch.cat([prompt_embeds, negative_prompt_embeds]) if self.do_classifier_free_guidance else prompt_embeds
|
||||
)
|
||||
effnet = (
|
||||
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
|
||||
if self.do_classifier_free_guidance
|
||||
else image_embeddings
|
||||
)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latents
|
||||
latents = self.prepare_latents(
|
||||
image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
|
||||
)
|
||||
|
||||
# 6. Run denoising loop
|
||||
self._num_timesteps = len(timesteps[:-1])
|
||||
for i, t in enumerate(self.progress_bar(timesteps[:-1])):
|
||||
timestep_ratio = t.expand(latents.size(0)).to(dtype)
|
||||
|
||||
# 7. Denoise latents
|
||||
predicted_latents = self.decoder(
|
||||
sample=torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents,
|
||||
timestep_ratio=torch.cat([timestep_ratio] * 2) if self.do_classifier_free_guidance else timestep_ratio,
|
||||
clip_text_pooled=prompt_embeds_pooled,
|
||||
effnet=effnet,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# 8. Check for classifier free guidance and apply it
|
||||
if self.do_classifier_free_guidance:
|
||||
predicted_latents_text, predicted_latents_uncond = predicted_latents.chunk(2)
|
||||
predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale)
|
||||
|
||||
# 9. Renoise latents to next timestep
|
||||
latents = self.scheduler.step(
|
||||
model_output=predicted_latents,
|
||||
timestep=timestep_ratio,
|
||||
sample=latents,
|
||||
generator=generator,
|
||||
).prev_sample
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
if output_type not in ["pt", "np", "pil", "latent"]:
|
||||
raise ValueError(
f"Only the output types `pt`, `np`, `pil` and `latent` are supported, not output_type={output_type}"
|
||||
)
|
||||
|
||||
if not output_type == "latent":
|
||||
# 10. Scale and decode the image latents with vq-vae
|
||||
latents = self.vqgan.config.scale_factor * latents
|
||||
images = self.vqgan.decode(latents).sample.clamp(0, 1)
|
||||
if output_type == "np":
|
||||
images = images.permute(0, 2, 3, 1).cpu().float().numpy()  # float() since numpy does not support bfloat16
|
||||
elif output_type == "pil":
|
||||
images = images.permute(0, 2, 3, 1).cpu().float().numpy()  # float() since numpy does not support bfloat16
|
||||
images = self.numpy_to_pil(images)
|
||||
else:
|
||||
images = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return images
|
||||
return ImagePipelineOutput(images)
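# Minimal usage sketch for chaining the two pipelines (hypothetical checkpoint names; adjust to the
# actually published repositories):
#
#   prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16).to("cuda")
#   decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=torch.bfloat16).to("cuda")
#   prior_output = prior(prompt="an astronaut riding a horse")
#   image = decoder(image_embeddings=prior_output.image_embeddings, prompt="an astronaut riding a horse").images[0]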
|
||||
@@ -0,0 +1,294 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...models import StableCascadeUNet
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
from ...utils import replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
|
||||
from .pipeline_stable_cascade import StableCascadeDecoderPipeline
|
||||
from .pipeline_stable_cascade_prior import StableCascadePriorPipeline
|
||||
|
||||
|
||||
TEXT2IMAGE_EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
>>> from diffusers import StableCascadeCombinedPipeline
|
||||
|
||||
>>> pipe = StableCascadeCombinedPipeline.from_pretrained("stabilityai/stable-cascade-combined", torch_dtype=torch.bfloat16).to(
|
||||
... "cuda"
|
||||
... )
|
||||
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
|
||||
>>> images = pipe(prompt=prompt).images
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class StableCascadeCombinedPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Combined Pipeline for text-to-image generation using Stable Cascade.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
The decoder tokenizer to be used for text inputs.
|
||||
text_encoder (`CLIPTextModel`):
|
||||
The decoder text encoder to be used for text inputs.
|
||||
decoder (`StableCascadeUNet`):
|
||||
The decoder model to be used for decoder image generation pipeline.
|
||||
scheduler (`DDPMWuerstchenScheduler`):
|
||||
The scheduler to be used for decoder image generation pipeline.
|
||||
vqgan (`PaellaVQModel`):
|
||||
The VQGAN model to be used for decoder image generation pipeline.
|
||||
prior_feature_extractor ([`~transformers.CLIPImageProcessor`]):
Model that extracts features from input images to be used as inputs for the `prior_image_encoder`.
prior_image_encoder ([`CLIPVisionModelWithProjection`]):
Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
||||
prior_prior (`StableCascadeUNet`):
|
||||
The prior model to be used for prior pipeline.
|
||||
prior_scheduler (`DDPMWuerstchenScheduler`):
|
||||
The scheduler to be used for prior pipeline.
|
||||
"""
|
||||
|
||||
_load_connected_pipes = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModel,
|
||||
decoder: StableCascadeUNet,
|
||||
scheduler: DDPMWuerstchenScheduler,
|
||||
vqgan: PaellaVQModel,
|
||||
prior_prior: StableCascadeUNet,
|
||||
prior_text_encoder: CLIPTextModel,
|
||||
prior_tokenizer: CLIPTokenizer,
|
||||
prior_scheduler: DDPMWuerstchenScheduler,
|
||||
prior_feature_extractor: Optional[CLIPImageProcessor] = None,
|
||||
prior_image_encoder: Optional[CLIPVisionModelWithProjection] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
decoder=decoder,
|
||||
scheduler=scheduler,
|
||||
vqgan=vqgan,
|
||||
prior_text_encoder=prior_text_encoder,
|
||||
prior_tokenizer=prior_tokenizer,
|
||||
prior_prior=prior_prior,
|
||||
prior_scheduler=prior_scheduler,
|
||||
prior_feature_extractor=prior_feature_extractor,
|
||||
prior_image_encoder=prior_image_encoder,
|
||||
)
|
||||
self.prior_pipe = StableCascadePriorPipeline(
|
||||
prior=prior_prior,
|
||||
text_encoder=prior_text_encoder,
|
||||
tokenizer=prior_tokenizer,
|
||||
scheduler=prior_scheduler,
|
||||
image_encoder=prior_image_encoder,
|
||||
feature_extractor=prior_feature_extractor,
|
||||
)
|
||||
self.decoder_pipe = StableCascadeDecoderPipeline(
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
decoder=decoder,
|
||||
scheduler=scheduler,
|
||||
vqgan=vqgan,
|
||||
)
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
||||
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id)
|
||||
self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id)
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
|
||||
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
|
||||
GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
|
||||
Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
|
||||
"""
|
||||
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
|
||||
def progress_bar(self, iterable=None, total=None):
|
||||
self.prior_pipe.progress_bar(iterable=iterable, total=total)
|
||||
self.decoder_pipe.progress_bar(iterable=iterable, total=total)
|
||||
|
||||
def set_progress_bar_config(self, **kwargs):
|
||||
self.prior_pipe.set_progress_bar_config(**kwargs)
|
||||
self.decoder_pipe.set_progress_bar_config(**kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
images: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
prior_num_inference_steps: int = 60,
|
||||
prior_timesteps: Optional[List[float]] = None,
|
||||
prior_guidance_scale: float = 4.0,
|
||||
num_inference_steps: int = 12,
|
||||
decoder_timesteps: Optional[List[float]] = None,
|
||||
decoder_guidance_scale: float = 0.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation for the prior and decoder.
|
||||
images (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, *optional*):
|
||||
The images to guide the image generation for the prior.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
|
||||
prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width in pixels of the generated image.
|
||||
prior_guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`prior_guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
|
||||
`prior_guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely linked
to the text `prompt`, usually at the expense of lower image quality.
|
||||
prior_num_inference_steps (`int`, *optional*, defaults to 60):
|
||||
The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference. For more specific timestep spacing, you can pass customized
|
||||
`prior_timesteps`
|
||||
num_inference_steps (`int`, *optional*, defaults to 12):
|
||||
The number of decoder denoising steps. More denoising steps usually lead to a higher quality image at
|
||||
the expense of slower inference. For more specific timestep spacing, you can pass customized
|
||||
`timesteps`
|
||||
decoder_guidance_scale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `decoder_guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
|
||||
(`np.array`) or `"pt"` (`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
prior_callback_on_step_end (`Callable`, *optional*):
|
||||
A function that is called at the end of each denoising step during inference. The function is called
|
||||
with the following arguments: `prior_callback_on_step_end(self: DiffusionPipeline, step: int, timestep:
|
||||
int, callback_kwargs: Dict)`.
|
||||
prior_callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the
|
||||
list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
|
||||
the `._callback_tensor_inputs` attribute of your pipeline class.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that is called at the end of each denoising step during inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
|
||||
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
prior_outputs = self.prior_pipe(
|
||||
prompt=prompt if prompt_embeds is None else None,
|
||||
images=images,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=prior_num_inference_steps,
|
||||
guidance_scale=prior_guidance_scale,
|
||||
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
output_type="pt",
|
||||
return_dict=True,
|
||||
callback_on_step_end=prior_callback_on_step_end,
|
||||
callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
image_embeddings = prior_outputs.image_embeddings
|
||||
prompt_embeds = prior_outputs.get("prompt_embeds", None)
|
||||
negative_prompt_embeds = prior_outputs.get("negative_prompt_embeds", None)
|
||||
|
||||
outputs = self.decoder_pipe(
|
||||
image_embeddings=image_embeddings,
|
||||
prompt=prompt if prompt_embeds is None else None,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=decoder_guidance_scale,
|
||||
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
generator=generator,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
return outputs
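# The text embeddings returned by the prior pipeline are passed straight into the decoder pipeline above,
# so the prompt only has to be encoded once per combined call; the decoder re-encodes `prompt` itself
# only if the prior output did not include `prompt_embeds`.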
|
||||
@@ -0,0 +1,614 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...models import StableCascadeUNet
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
from ...utils import BaseOutput, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:]
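# The default stage-C schedule uses 20 evenly spaced timesteps for the high-noise range (1.0 -> 2/3) and
# 10 more for the remaining range (2/3 -> 0.0); slicing off the first element of the second linspace avoids
# duplicating the 2/3 boundary, giving 30 timesteps in total.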
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import StableCascadePriorPipeline
|
||||
|
||||
>>> prior_pipe = StableCascadePriorPipeline.from_pretrained(
|
||||
... "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
|
||||
... ).to("cuda")
|
||||
|
||||
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
|
||||
>>> prior_output = prior_pipe(prompt)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class StableCascadePriorPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for StableCascadePriorPipeline.
|
||||
|
||||
Args:
|
||||
image_embeddings (`torch.FloatTensor` or `np.ndarray`):
Prior image embeddings for the text prompt.
|
||||
prompt_embeds (`torch.FloatTensor`):
|
||||
Text embeddings for the prompt.
|
||||
negative_prompt_embeds (`torch.FloatTensor`):
|
||||
Text embeddings for the negative prompt.
|
||||
"""
|
||||
|
||||
image_embeddings: Union[torch.FloatTensor, np.ndarray]
|
||||
prompt_embeds: Union[torch.FloatTensor, np.ndarray]
|
||||
negative_prompt_embeds: Union[torch.FloatTensor, np.ndarray]
|
||||
|
||||
|
||||
class StableCascadePriorPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for generating image prior for Stable Cascade.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
prior ([`StableCascadeUNet`]):
|
||||
The Stable Cascade prior, used to approximate the image embedding from the text and/or image embedding.
|
||||
text_encoder ([`CLIPTextModelWithProjection`]):
|
||||
Frozen text-encoder ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `image_encoder`.
|
||||
image_encoder ([`CLIPVisionModelWithProjection`]):
|
||||
Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
scheduler ([`DDPMWuerstchenScheduler`]):
|
||||
A scheduler to be used in combination with `prior` to generate image embedding.
|
||||
resolution_multiple (`float`, *optional*, defaults to 42.67):
Factor by which the requested pixel resolution is divided to obtain the resolution of the prior latents.
|
||||
"""
|
||||
|
||||
unet_name = "prior"
|
||||
text_encoder_name = "text_encoder"
|
||||
model_cpu_offload_seq = "image_encoder->text_encoder->prior"
|
||||
_optional_components = ["image_encoder", "feature_extractor"]
|
||||
_callback_tensor_inputs = ["latents", "text_encoder_hidden_states", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
prior: StableCascadeUNet,
|
||||
scheduler: DDPMWuerstchenScheduler,
|
||||
resolution_multiple: float = 42.67,
|
||||
feature_extractor: Optional[CLIPImageProcessor] = None,
|
||||
image_encoder: Optional[CLIPVisionModelWithProjection] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
feature_extractor=feature_extractor,
|
||||
prior=prior,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.register_to_config(resolution_multiple=resolution_multiple)
|
||||
|
||||
def prepare_latents(
|
||||
self, batch_size, height, width, num_images_per_prompt, dtype, device, generator, latents, scheduler
|
||||
):
|
||||
latent_shape = (
|
||||
num_images_per_prompt * batch_size,
|
||||
self.prior.config.in_channels,
|
||||
ceil(height / self.config.resolution_multiple),
|
||||
ceil(width / self.config.resolution_multiple),
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != latent_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latent_shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
return latents
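# Example of the shape computed above with the default `resolution_multiple` of 42.67: requesting a
# 1024x1024 image gives ceil(1024 / 42.67) = 24, i.e. latents of shape
# (batch_size * num_images_per_prompt, self.prior.config.in_channels, 24, 24).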
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
device,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
prompt=None,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
if prompt_embeds is None:
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
attention_mask = text_inputs.attention_mask
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
attention_mask = attention_mask[:, : self.tokenizer.model_max_length]
|
||||
|
||||
text_encoder_output = self.text_encoder(
|
||||
text_input_ids.to(device), attention_mask=attention_mask.to(device), output_hidden_states=True
|
||||
)
|
||||
prompt_embeds = text_encoder_output.hidden_states[-1]
|
||||
if prompt_embeds_pooled is None:
|
||||
prompt_embeds_pooled = text_encoder_output.text_embeds.unsqueeze(1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=self.text_encoder.dtype, device=device)
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
if negative_prompt_embeds is None and do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
negative_prompt_embeds_text_encoder_output = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=uncond_input.attention_mask.to(device),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.hidden_states[-1]
|
||||
negative_prompt_embeds_pooled = negative_prompt_embeds_text_encoder_output.text_embeds.unsqueeze(1)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
seq_len = negative_prompt_embeds_pooled.shape[1]
|
||||
negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.to(
|
||||
dtype=self.text_encoder.dtype, device=device
|
||||
)
|
||||
negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
# done duplicates
|
||||
|
||||
return prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled
|
||||
|
||||
def encode_image(self, images, device, dtype, batch_size, num_images_per_prompt):
|
||||
image_embeds = []
|
||||
for image in images:
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embed = self.image_encoder(image).image_embeds.unsqueeze(1)
|
||||
image_embeds.append(image_embed)
|
||||
image_embeds = torch.cat(image_embeds, dim=1)
|
||||
|
||||
image_embeds = image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1)
|
||||
negative_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
return image_embeds, negative_image_embeds
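# Each conditioning image contributes one pooled CLIP embedding of shape (1, 1, projection_dim); the
# embeddings are concatenated along the sequence dimension and then repeated for every prompt in the batch,
# while an all-zeros tensor of the same shape serves as the unconditional image embedding.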
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
images=None,
|
||||
image_embeds=None,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
prompt_embeds_pooled=None,
|
||||
negative_prompt_embeds=None,
|
||||
negative_prompt_embeds_pooled=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if prompt_embeds_pooled is not None and negative_prompt_embeds_pooled is not None:
|
||||
if prompt_embeds_pooled.shape != negative_prompt_embeds_pooled.shape:
|
||||
raise ValueError(
"`prompt_embeds_pooled` and `negative_prompt_embeds_pooled` must have the same shape when passed"
f" directly, but got: `prompt_embeds_pooled` {prompt_embeds_pooled.shape} !="
f" `negative_prompt_embeds_pooled` {negative_prompt_embeds_pooled.shape}."
|
||||
)
|
||||
|
||||
if image_embeds is not None and images is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `images`: {images} and `image_embeds`: {image_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
|
||||
if images:
|
||||
for i, image in enumerate(images):
|
||||
if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
|
||||
raise TypeError(
f"'images' must contain images of type 'torch.Tensor' or 'PIL.Image.Image', but got"
f" {type(image)} for image number {i}."
|
||||
)
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
def get_t_condioning(self, t, alphas_cumprod):
|
||||
s = torch.tensor([0.003])
|
||||
clamp_range = [0, 1]
|
||||
min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
|
||||
var = alphas_cumprod[t]
|
||||
var = var.clamp(*clamp_range)
|
||||
s, min_var = s.to(var.device), min_var.to(var.device)
|
||||
ratio = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
|
||||
return ratio
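# Maps a discrete scheduler timestep to the continuous timestep ratio the prior expects, by inverting a
# cosine alphas_cumprod schedule (the same parametrization used by the Wuerstchen prior). It is only needed
# when a standard DDPM/DDIM scheduler is used instead of `DDPMWuerstchenScheduler`.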
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
images: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None,
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
num_inference_steps: int = 20,
|
||||
timesteps: List[float] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
image_embeds: Optional[torch.FloatTensor] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pt",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*, defaults to 1024):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 1024):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 20):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
`guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely
linked to the text `prompt`, usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt` input
|
||||
argument.
|
||||
image_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated image embeddings. Can be used to easily tweak image inputs, *e.g.* prompt weighting.
If not provided, image embeddings will be generated from the `images` input argument if provided.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pt"`):
The output format of the generated image embeddings. Choose between: `"np"` (`np.ndarray`) or `"pt"`
(`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`StableCascadePriorPipelineOutput`] instead of a plain tuple.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that is called at the end of each denoising step during inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`StableCascadePriorPipelineOutput`] or `tuple` [`StableCascadePriorPipelineOutput`] if
|
||||
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
||||
generated image embeddings.
|
||||
"""
|
||||
|
||||
# 0. Define commonly used variables
|
||||
device = self._execution_device
|
||||
dtype = next(self.prior.parameters()).dtype
|
||||
self._guidance_scale = guidance_scale
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
images=images,
|
||||
image_embeds=image_embeds,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_pooled=prompt_embeds_pooled,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
# 2. Encode caption + images
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_embeds_pooled,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_embeds_pooled,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_pooled=prompt_embeds_pooled,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
|
||||
)
|
||||
|
||||
if images is not None:
|
||||
image_embeds_pooled, uncond_image_embeds_pooled = self.encode_image(
|
||||
images=images,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
)
|
||||
elif image_embeds is not None:
|
||||
image_embeds_pooled = image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1)
|
||||
uncond_image_embeds_pooled = torch.zeros_like(image_embeds_pooled)
|
||||
else:
|
||||
image_embeds_pooled = torch.zeros(
|
||||
batch_size * num_images_per_prompt,
|
||||
1,
|
||||
self.prior.config.clip_image_in_channels,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
uncond_image_embeds_pooled = torch.zeros(
|
||||
batch_size * num_images_per_prompt,
|
||||
1,
|
||||
self.prior.config.clip_image_in_channels,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([image_embeds_pooled, uncond_image_embeds_pooled], dim=0)
|
||||
else:
|
||||
image_embeds = image_embeds_pooled
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_encoder_hidden_states = (
|
||||
torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds
|
||||
)
|
||||
text_encoder_pooled = (
|
||||
torch.cat([prompt_embeds_pooled, negative_prompt_embeds_pooled])
|
||||
if negative_prompt_embeds is not None
|
||||
else prompt_embeds_pooled
|
||||
)
|
||||
|
||||
# 4. Prepare and set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latents
|
||||
latents = self.prepare_latents(
|
||||
batch_size, height, width, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
|
||||
)
|
||||
|
||||
if isinstance(self.scheduler, DDPMWuerstchenScheduler):
|
||||
timesteps = timesteps[:-1]
|
||||
else:
|
||||
if self.scheduler.config.clip_sample:
|
||||
self.scheduler.config.clip_sample = False  # disable sample clipping
logger.warning("Set `clip_sample` to False")
|
||||
# 6. Run denoising loop
|
||||
if hasattr(self.scheduler, "betas"):
|
||||
alphas = 1.0 - self.scheduler.betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
else:
|
||||
alphas_cumprod = []
|
||||
|
||||
self._num_timesteps = len(timesteps)
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
|
||||
if len(alphas_cumprod) > 0:
|
||||
timestep_ratio = self.get_t_condioning(t.long().cpu(), alphas_cumprod)
|
||||
timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
|
||||
else:
|
||||
timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
|
||||
else:
|
||||
timestep_ratio = t.expand(latents.size(0)).to(dtype)
|
||||
# 7. Denoise image embeddings
|
||||
predicted_image_embedding = self.prior(
|
||||
sample=torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents,
|
||||
timestep_ratio=torch.cat([timestep_ratio] * 2) if self.do_classifier_free_guidance else timestep_ratio,
|
||||
clip_text_pooled=text_encoder_pooled,
|
||||
clip_text=text_encoder_hidden_states,
|
||||
clip_img=image_embeds,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# 8. Check for classifier free guidance and apply it
|
||||
if self.do_classifier_free_guidance:
|
||||
predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2)
|
||||
predicted_image_embedding = torch.lerp(
|
||||
predicted_image_embedding_uncond, predicted_image_embedding_text, self.guidance_scale
|
||||
)
|
||||

            # 9. Renoise latents to next timestep
            if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
                timestep_ratio = t
            latents = self.scheduler.step(
                model_output=predicted_image_embedding, timestep=timestep_ratio, sample=latents, generator=generator
            ).prev_sample

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

        # Offload all models
        self.maybe_free_model_hooks()

        if output_type == "np":
            latents = latents.cpu().float().numpy()  # float() since bfloat16 -> numpy conversion is not supported
            prompt_embeds = prompt_embeds.cpu().float().numpy()  # float() since bfloat16 -> numpy conversion is not supported
            negative_prompt_embeds = (
                negative_prompt_embeds.cpu().float().numpy() if negative_prompt_embeds is not None else None
            )  # float() since bfloat16 -> numpy conversion is not supported

        if not return_dict:
            return (latents, prompt_embeds, negative_prompt_embeds)

        return StableCascadePriorPipelineOutput(latents, prompt_embeds, negative_prompt_embeds)

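For reference, a minimal usage sketch of the prior stage implemented above. It relies only on names
that appear elsewhere in this diff ("diffusers/StableCascade-prior" is the checkpoint id used by the
slow tests below); the prompt, step count, and guidance scale are illustrative, not prescriptive.

import torch

from diffusers import StableCascadePriorPipeline

# Load the prior in bfloat16 and offload submodules between forward passes to save VRAM.
prior = StableCascadePriorPipeline.from_pretrained(
    "diffusers/StableCascade-prior", torch_dtype=torch.bfloat16
)
prior.enable_model_cpu_offload()

# The prior denoises image embeddings rather than pixels; the decoder stage consumes them later.
prior_output = prior(
    prompt="A photograph of a raccoon reading a newspaper",
    num_inference_steps=20,
    guidance_scale=4.0,
    generator=torch.Generator(device="cpu").manual_seed(0),
)
image_embeddings = prior_output.image_embeddings  # e.g. (1, 16, 24, 24) at the default resolution
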
@@ -1,18 +1,3 @@
# Copyright (c) 2023 Dominic Rampas MIT License
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn


@@ -233,7 +233,7 @@ class WuerstchenDiffNeXt(ModelMixin, ConfigMixin):


class ResBlockStageB(nn.Module):
-    def __init__(self, c, c_skip=None, kernel_size=3, dropout=0.0):
+    def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
        super().__init__()
        self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
        self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)

@@ -349,6 +349,11 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
        text_encoder_hidden_states = (
            torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds
        )
+        effnet = (
+            torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
+            if self.do_classifier_free_guidance
+            else image_embeddings
+        )

        # 3. Determine latent shape of latents
        latent_height = int(image_embeddings.size(2) * self.config.latent_dim_scale)
@@ -371,11 +376,6 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
        self._num_timesteps = len(timesteps[:-1])
        for i, t in enumerate(self.progress_bar(timesteps[:-1])):
            ratio = t.expand(latents.size(0)).to(dtype)
-            effnet = (
-                torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
-                if self.do_classifier_free_guidance
-                else image_embeddings
-            )
            # 7. Denoise latents
            predicted_latents = self.decoder(
                torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents,
@@ -423,9 +423,9 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
        latents = self.vqgan.config.scale_factor * latents
        images = self.vqgan.decode(latents).sample.clamp(0, 1)
        if output_type == "np":
-            images = images.permute(0, 2, 3, 1).cpu().numpy()
+            images = images.permute(0, 2, 3, 1).cpu().float().numpy()
        elif output_type == "pil":
-            images = images.permute(0, 2, 3, 1).cpu().numpy()
+            images = images.permute(0, 2, 3, 1).cpu().float().numpy()
            images = self.numpy_to_pil(images)
        else:
            images = latents

@@ -508,7 +508,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
        self.maybe_free_model_hooks()

        if output_type == "np":
-            latents = latents.cpu().numpy()
+            latents = latents.cpu().float().numpy()

        if not return_dict:
            return (latents,)

@@ -752,6 +752,51 @@ class ShapEPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class StableCascadeCombinedPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class StableCascadeDecoderPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class StableCascadePriorPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class StableDiffusionAdapterPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

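The dummy classes added above are what `diffusers` exposes when `torch` or `transformers` is not
installed: the import still succeeds, but any use of the class raises an informative error via
`requires_backends`. A rough illustration of that behaviour, sketched from the stubs above rather
than taken from a test:

# With transformers missing, the pipeline name resolves to the DummyObject stub defined above.
from diffusers import StableCascadePriorPipeline

try:
    StableCascadePriorPipeline()
except ImportError as err:
    # The error message tells the user which backend(s) need to be installed.
    print(err)
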
tests/pipelines/stable_cascade/__init__.py (new file, 0 lines)
tests/pipelines/stable_cascade/test_stable_cascade_combined.py (new file, 246 lines)
@@ -0,0 +1,246 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import DDPMWuerstchenScheduler, StableCascadeCombinedPipeline
|
||||
from diffusers.models import StableCascadeUNet
|
||||
from diffusers.pipelines.wuerstchen import PaellaVQModel
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, torch_device
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class StableCascadeCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = StableCascadeCombinedPipeline
|
||||
params = ["prompt"]
|
||||
batch_params = ["prompt", "negative_prompt"]
|
||||
required_optional_params = [
|
||||
"generator",
|
||||
"height",
|
||||
"width",
|
||||
"latents",
|
||||
"prior_guidance_scale",
|
||||
"decoder_guidance_scale",
|
||||
"negative_prompt",
|
||||
"num_inference_steps",
|
||||
"return_dict",
|
||||
"prior_num_inference_steps",
|
||||
"output_type",
|
||||
]
|
||||
test_xformers_attention = True
|
||||
|
||||
@property
|
||||
def text_embedder_hidden_size(self):
|
||||
return 32
|
||||
|
||||
@property
|
||||
def dummy_prior(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
model_kwargs = {
|
||||
"conditioning_dim": 128,
|
||||
"block_out_channels": (128, 128),
|
||||
"num_attention_heads": (2, 2),
|
||||
"down_num_layers_per_block": (1, 1),
|
||||
"up_num_layers_per_block": (1, 1),
|
||||
"clip_image_in_channels": 768,
|
||||
"switch_level": (False,),
|
||||
"clip_text_in_channels": self.text_embedder_hidden_size,
|
||||
"clip_text_pooled_in_channels": self.text_embedder_hidden_size,
|
||||
}
|
||||
|
||||
model = StableCascadeUNet(**model_kwargs)
|
||||
return model.eval()
|
||||
|
||||
@property
|
||||
def dummy_tokenizer(self):
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
return tokenizer
|
||||
|
||||
@property
|
||||
def dummy_text_encoder(self):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
projection_dim=self.text_embedder_hidden_size,
|
||||
hidden_size=self.text_embedder_hidden_size,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
return CLIPTextModelWithProjection(config).eval()
|
||||
|
||||
@property
|
||||
def dummy_vqgan(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
model_kwargs = {
|
||||
"bottleneck_blocks": 1,
|
||||
"num_vq_embeddings": 2,
|
||||
}
|
||||
model = PaellaVQModel(**model_kwargs)
|
||||
return model.eval()
|
||||
|
||||
@property
|
||||
def dummy_decoder(self):
|
||||
torch.manual_seed(0)
|
||||
model_kwargs = {
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"conditioning_dim": 128,
|
||||
"block_out_channels": (16, 32, 64, 128),
|
||||
"num_attention_heads": (-1, -1, 1, 2),
|
||||
"down_num_layers_per_block": (1, 1, 1, 1),
|
||||
"up_num_layers_per_block": (1, 1, 1, 1),
|
||||
"down_blocks_repeat_mappers": (1, 1, 1, 1),
|
||||
"up_blocks_repeat_mappers": (3, 3, 2, 2),
|
||||
"block_types_per_layer": (
|
||||
("SDCascadeResBlock", "SDCascadeTimestepBlock"),
|
||||
("SDCascadeResBlock", "SDCascadeTimestepBlock"),
|
||||
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
|
||||
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
|
||||
),
|
||||
"switch_level": None,
|
||||
"clip_text_pooled_in_channels": 32,
|
||||
"dropout": (0.1, 0.1, 0.1, 0.1),
|
||||
}
|
||||
|
||||
model = StableCascadeUNet(**model_kwargs)
|
||||
return model.eval()
|
||||
|
||||
def get_dummy_components(self):
|
||||
prior = self.dummy_prior
|
||||
|
||||
scheduler = DDPMWuerstchenScheduler()
|
||||
tokenizer = self.dummy_tokenizer
|
||||
text_encoder = self.dummy_text_encoder
|
||||
decoder = self.dummy_decoder
|
||||
vqgan = self.dummy_vqgan
|
||||
prior_text_encoder = self.dummy_text_encoder
|
||||
prior_tokenizer = self.dummy_tokenizer
|
||||
|
||||
components = {
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"decoder": decoder,
|
||||
"scheduler": scheduler,
|
||||
"vqgan": vqgan,
|
||||
"prior_text_encoder": prior_text_encoder,
|
||||
"prior_tokenizer": prior_tokenizer,
|
||||
"prior_prior": prior,
|
||||
"prior_scheduler": scheduler,
|
||||
}
|
||||
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "horse",
|
||||
"generator": generator,
|
||||
"prior_guidance_scale": 4.0,
|
||||
"decoder_guidance_scale": 4.0,
|
||||
"num_inference_steps": 2,
|
||||
"prior_num_inference_steps": 2,
|
||||
"output_type": "np",
|
||||
"height": 128,
|
||||
"width": 128,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_stable_cascade(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
output = pipe(**self.get_dummy_inputs(device))
|
||||
image = output.images
|
||||
|
||||
image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[-3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
|
||||
expected_slice = np.array([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0])
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
assert (
|
||||
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
|
||||
@require_torch_gpu
|
||||
def test_offloads(self):
|
||||
pipes = []
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = self.pipeline_class(**components).to(torch_device)
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = self.pipeline_class(**components)
|
||||
sd_pipe.enable_sequential_cpu_offload()
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = self.pipeline_class(**components)
|
||||
sd_pipe.enable_model_cpu_offload()
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
image_slices = []
|
||||
for pipe in pipes:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
image = pipe(**inputs).images
|
||||
|
||||
image_slices.append(image[0, -3:, -3:, -1].flatten())
|
||||
|
||||
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=2e-2)
|
||||
|
||||
@unittest.skip(reason="fp16 not supported")
|
||||
def test_float16_inference(self):
|
||||
super().test_float16_inference()
|
||||
|
||||
@unittest.skip(reason="no callback test for combined pipeline")
|
||||
def test_callback_inputs(self):
|
||||
super().test_callback_inputs()
|
||||
|
||||
    # def test_callback_cfg(self):
    #     pass
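StableCascadeCombinedPipeline wires the prior and decoder stages behind a single text-to-image call,
which is what the fast tests above exercise with dummy components. A minimal sketch of real usage;
"stabilityai/stable-cascade-combined" is a placeholder checkpoint id, and the argument values are
illustrative only (they mirror the parameter names from get_dummy_inputs):

import torch

from diffusers import StableCascadeCombinedPipeline

pipe = StableCascadeCombinedPipeline.from_pretrained(
    "stabilityai/stable-cascade-combined", torch_dtype=torch.bfloat16  # placeholder id
)
pipe.enable_model_cpu_offload()

image = pipe(
    prompt="A photograph of a raccoon reading a newspaper",
    height=1024,
    width=1024,
    prior_num_inference_steps=20,
    prior_guidance_scale=4.0,
    num_inference_steps=10,
    decoder_guidance_scale=4.0,
    output_type="pil",
).images[0]
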
tests/pipelines/stable_cascade/test_stable_cascade_decoder.py (new file, 249 lines)
@@ -0,0 +1,249 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import DDPMWuerstchenScheduler, StableCascadeDecoderPipeline
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models import StableCascadeUNet
|
||||
from diffusers.pipelines.wuerstchen import PaellaVQModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
load_image,
|
||||
load_pt,
|
||||
require_torch_gpu,
|
||||
skip_mps,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class StableCascadeDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = StableCascadeDecoderPipeline
|
||||
params = ["prompt"]
|
||||
batch_params = ["image_embeddings", "prompt", "negative_prompt"]
|
||||
required_optional_params = [
|
||||
"num_images_per_prompt",
|
||||
"num_inference_steps",
|
||||
"latents",
|
||||
"negative_prompt",
|
||||
"guidance_scale",
|
||||
"output_type",
|
||||
"return_dict",
|
||||
]
|
||||
test_xformers_attention = False
|
||||
callback_cfg_params = ["image_embeddings", "text_encoder_hidden_states"]
|
||||
|
||||
@property
|
||||
def text_embedder_hidden_size(self):
|
||||
return 32
|
||||
|
||||
@property
|
||||
def time_input_dim(self):
|
||||
return 32
|
||||
|
||||
@property
|
||||
def block_out_channels_0(self):
|
||||
return self.time_input_dim
|
||||
|
||||
@property
|
||||
def time_embed_dim(self):
|
||||
return self.time_input_dim * 4
|
||||
|
||||
@property
|
||||
def dummy_tokenizer(self):
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
return tokenizer
|
||||
|
||||
@property
|
||||
def dummy_text_encoder(self):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
projection_dim=self.text_embedder_hidden_size,
|
||||
hidden_size=self.text_embedder_hidden_size,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
return CLIPTextModelWithProjection(config).eval()
|
||||
|
||||
@property
|
||||
def dummy_vqgan(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
model_kwargs = {
|
||||
"bottleneck_blocks": 1,
|
||||
"num_vq_embeddings": 2,
|
||||
}
|
||||
model = PaellaVQModel(**model_kwargs)
|
||||
return model.eval()
|
||||
|
||||
@property
|
||||
def dummy_decoder(self):
|
||||
torch.manual_seed(0)
|
||||
model_kwargs = {
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"conditioning_dim": 128,
|
||||
"block_out_channels": [16, 32, 64, 128],
|
||||
"num_attention_heads": [-1, -1, 1, 2],
|
||||
"down_num_layers_per_block": [1, 1, 1, 1],
|
||||
"up_num_layers_per_block": [1, 1, 1, 1],
|
||||
"down_blocks_repeat_mappers": [1, 1, 1, 1],
|
||||
"up_blocks_repeat_mappers": [3, 3, 2, 2],
|
||||
"block_types_per_layer": [
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
],
|
||||
"switch_level": None,
|
||||
"clip_text_pooled_in_channels": 32,
|
||||
"dropout": [0.1, 0.1, 0.1, 0.1],
|
||||
}
|
||||
model = StableCascadeUNet(**model_kwargs)
|
||||
return model.eval()
|
||||
|
||||
def get_dummy_components(self):
|
||||
decoder = self.dummy_decoder
|
||||
text_encoder = self.dummy_text_encoder
|
||||
tokenizer = self.dummy_tokenizer
|
||||
vqgan = self.dummy_vqgan
|
||||
|
||||
scheduler = DDPMWuerstchenScheduler()
|
||||
|
||||
components = {
|
||||
"decoder": decoder,
|
||||
"vqgan": vqgan,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"scheduler": scheduler,
|
||||
"latent_dim_scale": 4.0,
|
||||
}
|
||||
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"image_embeddings": torch.ones((1, 4, 4, 4), device=device),
|
||||
"prompt": "horse",
|
||||
"generator": generator,
|
||||
"guidance_scale": 2.0,
|
||||
"num_inference_steps": 2,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_wuerstchen_decoder(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
output = pipe(**self.get_dummy_inputs(device))
|
||||
image = output.images
|
||||
|
||||
        image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
expected_slice = np.array([0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@skip_mps
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=1e-2)
|
||||
|
||||
@skip_mps
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
test_max_difference = torch_device == "cpu"
|
||||
test_mean_pixel_difference = False
|
||||
|
||||
self._test_attention_slicing_forward_pass(
|
||||
test_max_difference=test_max_difference,
|
||||
test_mean_pixel_difference=test_mean_pixel_difference,
|
||||
)
|
||||
|
||||
@unittest.skip(reason="fp16 not supported")
|
||||
def test_float16_inference(self):
|
||||
super().test_float16_inference()
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class StableCascadeDecoderPipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_stable_cascade_decoder(self):
|
||||
pipe = StableCascadeDecoderPipeline.from_pretrained(
|
||||
"diffusers/StableCascade-decoder", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image_embedding = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/image_embedding.pt"
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt, image_embeddings=image_embedding, num_inference_steps=10, generator=generator
|
||||
).images[0]
|
||||
|
||||
assert image.size == (1024, 1024)
|
||||
|
||||
expected_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/t2i.png"
|
||||
)
|
||||
|
||||
image_processor = VaeImageProcessor()
|
||||
|
||||
image_np = image_processor.pil_to_numpy(image)
|
||||
expected_image_np = image_processor.pil_to_numpy(expected_image)
|
||||
|
||||
self.assertTrue(np.allclose(image_np, expected_image_np, atol=53e-2))
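The integration test above feeds the decoder a precomputed image embedding; outside the test harness
that embedding comes from StableCascadePriorPipeline. A sketch of the full two-stage flow using the
checkpoint ids that appear in these slow tests (step counts are illustrative):

import torch

from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats."

# First stage: text -> image embeddings
prior = StableCascadePriorPipeline.from_pretrained("diffusers/StableCascade-prior", torch_dtype=torch.bfloat16)
prior.enable_model_cpu_offload()
image_embeddings = prior(prompt=prompt, num_inference_steps=20).image_embeddings

# Second stage: image embeddings -> pixels
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "diffusers/StableCascade-decoder", torch_dtype=torch.bfloat16
)
decoder.enable_model_cpu_offload()
image = decoder(prompt=prompt, image_embeddings=image_embeddings, num_inference_steps=10).images[0]
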
tests/pipelines/stable_cascade/test_stable_cascade_prior.py (new file, 308 lines)
@@ -0,0 +1,308 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import DDPMWuerstchenScheduler, StableCascadePriorPipeline
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models import StableCascadeUNet
|
||||
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
|
||||
from diffusers.utils.import_utils import is_peft_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
load_pt,
|
||||
require_peft_backend,
|
||||
require_torch_gpu,
|
||||
skip_mps,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from peft import LoraConfig
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
def create_prior_lora_layers(unet: nn.Module):
|
||||
lora_attn_procs = {}
|
||||
for name in unet.attn_processors.keys():
|
||||
lora_attn_processor_class = (
|
||||
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
||||
)
|
||||
lora_attn_procs[name] = lora_attn_processor_class(
|
||||
hidden_size=unet.config.c,
|
||||
)
|
||||
unet_lora_layers = AttnProcsLayers(lora_attn_procs)
|
||||
return lora_attn_procs, unet_lora_layers
|
||||
|
||||
|
||||
class StableCascadePriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = StableCascadePriorPipeline
|
||||
params = ["prompt"]
|
||||
batch_params = ["prompt", "negative_prompt"]
|
||||
required_optional_params = [
|
||||
"num_images_per_prompt",
|
||||
"generator",
|
||||
"num_inference_steps",
|
||||
"latents",
|
||||
"negative_prompt",
|
||||
"guidance_scale",
|
||||
"output_type",
|
||||
"return_dict",
|
||||
]
|
||||
test_xformers_attention = False
|
||||
callback_cfg_params = ["text_encoder_hidden_states"]
|
||||
|
||||
@property
|
||||
def text_embedder_hidden_size(self):
|
||||
return 32
|
||||
|
||||
@property
|
||||
def time_input_dim(self):
|
||||
return 32
|
||||
|
||||
@property
|
||||
def block_out_channels_0(self):
|
||||
return self.time_input_dim
|
||||
|
||||
@property
|
||||
def time_embed_dim(self):
|
||||
return self.time_input_dim * 4
|
||||
|
||||
@property
|
||||
def dummy_tokenizer(self):
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
return tokenizer
|
||||
|
||||
@property
|
||||
def dummy_text_encoder(self):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=self.text_embedder_hidden_size,
|
||||
projection_dim=self.text_embedder_hidden_size,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
return CLIPTextModelWithProjection(config).eval()
|
||||
|
||||
@property
|
||||
def dummy_prior(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
model_kwargs = {
|
||||
"conditioning_dim": 128,
|
||||
"block_out_channels": (128, 128),
|
||||
"num_attention_heads": (2, 2),
|
||||
"down_num_layers_per_block": (1, 1),
|
||||
"up_num_layers_per_block": (1, 1),
|
||||
"switch_level": (False,),
|
||||
"clip_image_in_channels": 768,
|
||||
"clip_text_in_channels": self.text_embedder_hidden_size,
|
||||
"clip_text_pooled_in_channels": self.text_embedder_hidden_size,
|
||||
"dropout": (0.1, 0.1),
|
||||
}
|
||||
|
||||
model = StableCascadeUNet(**model_kwargs)
|
||||
return model.eval()
|
||||
|
||||
def get_dummy_components(self):
|
||||
prior = self.dummy_prior
|
||||
text_encoder = self.dummy_text_encoder
|
||||
tokenizer = self.dummy_tokenizer
|
||||
|
||||
scheduler = DDPMWuerstchenScheduler()
|
||||
|
||||
components = {
|
||||
"prior": prior,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"scheduler": scheduler,
|
||||
"feature_extractor": None,
|
||||
"image_encoder": None,
|
||||
}
|
||||
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "horse",
|
||||
"generator": generator,
|
||||
"guidance_scale": 4.0,
|
||||
"num_inference_steps": 2,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_wuerstchen_prior(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
output = pipe(**self.get_dummy_inputs(device))
|
||||
image = output.image_embeddings
|
||||
|
||||
image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0]
|
||||
|
||||
image_slice = image[0, 0, 0, -10:]
|
||||
image_from_tuple_slice = image_from_tuple[0, 0, 0, -10:]
|
||||
assert image.shape == (1, 16, 24, 24)
|
||||
|
||||
expected_slice = np.array(
|
||||
[
|
||||
96.139565,
|
||||
-20.213179,
|
||||
-116.40341,
|
||||
-191.57129,
|
||||
39.350136,
|
||||
74.80767,
|
||||
39.782352,
|
||||
-184.67352,
|
||||
-46.426907,
|
||||
168.41783,
|
||||
]
|
||||
)
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 5e-2
|
||||
|
||||
@skip_mps
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=2e-1)
|
||||
|
||||
@skip_mps
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
test_max_difference = torch_device == "cpu"
|
||||
test_mean_pixel_difference = False
|
||||
|
||||
self._test_attention_slicing_forward_pass(
|
||||
test_max_difference=test_max_difference,
|
||||
test_mean_pixel_difference=test_mean_pixel_difference,
|
||||
)
|
||||
|
||||
@unittest.skip(reason="fp16 not supported")
|
||||
def test_float16_inference(self):
|
||||
super().test_float16_inference()
|
||||
|
||||
def check_if_lora_correctly_set(self, model) -> bool:
|
||||
"""
|
||||
Checks if the LoRA layers are correctly set with peft
|
||||
"""
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_lora_components(self):
|
||||
prior = self.dummy_prior
|
||||
|
||||
prior_lora_config = LoraConfig(
|
||||
r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
|
||||
)
|
||||
|
||||
prior_lora_attn_procs, prior_lora_layers = create_prior_lora_layers(prior)
|
||||
|
||||
lora_components = {
|
||||
"prior_lora_layers": prior_lora_layers,
|
||||
"prior_lora_attn_procs": prior_lora_attn_procs,
|
||||
}
|
||||
|
||||
return prior, prior_lora_config, lora_components
|
||||
|
||||
@require_peft_backend
|
||||
@unittest.skip(reason="no lora support for now")
|
||||
def test_inference_with_prior_lora(self):
|
||||
_, prior_lora_config, _ = self.get_lora_components()
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
output_no_lora = pipe(**self.get_dummy_inputs(device))
|
||||
image_embed = output_no_lora.image_embeddings
|
||||
self.assertTrue(image_embed.shape == (1, 16, 24, 24))
|
||||
|
||||
pipe.prior.add_adapter(prior_lora_config)
|
||||
self.assertTrue(self.check_if_lora_correctly_set(pipe.prior), "Lora not correctly set in prior")
|
||||
|
||||
output_lora = pipe(**self.get_dummy_inputs(device))
|
||||
lora_image_embed = output_lora.image_embeddings
|
||||
|
||||
self.assertTrue(image_embed.shape == lora_image_embed.shape)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class StableCascadePriorPipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_stable_cascade_prior(self):
|
||||
pipe = StableCascadePriorPipeline.from_pretrained("diffusers/StableCascade-prior", torch_dtype=torch.bfloat16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
|
||||
output = pipe(prompt, num_inference_steps=10, generator=generator)
|
||||
image_embedding = output.image_embeddings
|
||||
|
||||
expected_image_embedding = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/image_embedding.pt"
|
||||
)
|
||||
|
||||
assert image_embedding.shape == (1, 16, 24, 24)
|
||||
|
||||
self.assertTrue(
|
||||
np.allclose(
|
||||
image_embedding.cpu().float().numpy(), expected_image_embedding.cpu().float().numpy(), atol=5e-2
|
||||
)
|
||||
)
|
||||
@@ -45,7 +45,6 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase
        "return_dict",
        "prior_num_inference_steps",
        "output_type",
-        "return_dict",
    ]
    test_xformers_attention = True
