mirror of https://github.com/huggingface/diffusers.git

mixed inference for text2img
@@ -648,6 +648,7 @@ if __name__ == "__main__":
            beta_start=beta_start,
            num_train_timesteps=num_train_timesteps,
            skip_prk_steps=True,
            steps_offset=1,
        )
    elif args.scheduler_type == "lms":
        scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
@@ -668,12 +669,14 @@ if __name__ == "__main__":
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
        )
    else:
        raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")

-    # Convert the UNet2DConditionModel model.
+    # Convert the UNet2DConditionModel models.
    if args.unet_checkpoint_path is not None:
        # image UNet
        image_unet_config = create_unet_diffusers_config(IMAGE_UNET_CONFIG)
        checkpoint = torch.load(args.unet_checkpoint_path)
        converted_image_unet_checkpoint = convert_vd_unet_checkpoint(
@@ -682,11 +685,28 @@ if __name__ == "__main__":
        image_unet = UNet2DConditionModel(**image_unet_config)
        image_unet.load_state_dict(converted_image_unet_checkpoint)

        # text UNet
        text_unet_config = create_unet_diffusers_config(TEXT_UNET_CONFIG)
        converted_text_unet_checkpoint = convert_vd_unet_checkpoint(
            checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema
        )
        text_unet = UNet2DConditionModel(**text_unet_config)
        # TEMP hack to skip converting the 1x1 blocks for the text unet
        del converted_text_unet_checkpoint["conv_in.weight"]
        del converted_text_unet_checkpoint["conv_in.bias"]
        del converted_text_unet_checkpoint["conv_out.weight"]
        for block in ["down_blocks", "mid_block", "up_blocks"]:
            for i in range(4):
                for j in range(3):
                    for module in ["time_emb_proj", "conv1", "norm1", "conv2", "norm2", "conv_shortcut"]:
                        for param in ["weight", "bias"]:
                            if block == "mid_block":
                                key = f"{block}.resnets.{j}.{module}.{param}"
                            else:
                                key = f"{block}.{i}.resnets.{j}.{module}.{param}"
                            if key in converted_text_unet_checkpoint:
                                del converted_text_unet_checkpoint[key]
        # END TEMP hack
        text_unet.load_state_dict(converted_text_unet_checkpoint, strict=False)

    # Convert the VAE model.
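The TEMP hack above drops the text UNet's 1x1 convolution and resnet parameters from the converted checkpoint and then loads it non-strictly, since those shapes differ between the image and text variants. The same idea can be written as a shape-based filter; a minimal sketch (the helper name `load_filtered_state_dict` is hypothetical, `text_unet` and `converted_text_unet_checkpoint` are the objects from the diff):

```python
import torch
import torch.nn as nn

def load_filtered_state_dict(module: nn.Module, state_dict: dict) -> None:
    # Keep only keys whose tensors exist in the module with a matching shape,
    # then load non-strictly so the remaining parameters keep their init values.
    own = module.state_dict()
    filtered = {k: v for k, v in state_dict.items() if k in own and own[k].shape == v.shape}
    result = module.load_state_dict(filtered, strict=False)
    print(f"missing: {len(result.missing_keys)}, unexpected: {len(result.unexpected_keys)}")

# e.g. load_filtered_state_dict(text_unet, converted_text_unet_checkpoint)
```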
@@ -13,21 +13,53 @@
 # limitations under the License.

import inspect
from functools import reduce
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint

from transformers import CLIPProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel

from ...models import AutoencoderKL, UNet2DConditionModel, VQModel
from ...models.unet_2d_condition import UNet2DConditionOutput
from ...models.attention import Transformer2DModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler


class VersatileMixedModel:
    """
    A context manager that swaps the transformer modules between the image and text UNets during inference,
    depending on the latent type and condition type.
    """

    def __init__(self, image_unet, text_unet, latent_type, condition_type):
        self.image_unet = image_unet
        self.text_unet = text_unet
        self.latent_type = latent_type
        self.condition_type = condition_type

    def swap_transformer_modules(self):
        for name, module in self.image_unet.named_modules():
            if isinstance(module, Transformer2DModel):
                parent_name, index = name.rsplit(".", 1)
                index = int(index)
                self.image_unet.get_submodule(parent_name)[index], self.text_unet.get_submodule(parent_name)[index] = (
                    self.text_unet.get_submodule(parent_name)[index],
                    self.image_unet.get_submodule(parent_name)[index],
                )

    def __enter__(self):
        if self.latent_type != self.condition_type:
            self.swap_transformer_modules()
        return self.image_unet if self.latent_type == "image" else self.text_unet

    def __exit__(self, *exc):
        # swap the modules back
        if self.latent_type != self.condition_type:
            self.swap_transformer_modules()
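In isolation, the context manager is used as in the `__call__` hunk further down; a short sketch (`image_unet` and `text_unet` are the converted `UNet2DConditionModel` instances, the other names mirror the pipeline code):

```python
# text2img: image-space latents conditioned on text. Because latent type and
# condition type differ, entering the block swaps the text UNet's attention
# transformers into the image UNet; exiting swaps them back.
with VersatileMixedModel(image_unet, text_unet, "image", "text") as unet:
    noise_pred = unet(latents, t, encoder_hidden_states=text_embeddings).sample
```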
class VersatileDiffusionPipeline(DiffusionPipeline):
    r"""
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
@@ -51,6 +83,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
    text_encoder: CLIPTextModel
    image_encoder: CLIPVisionModel
    image_unet: UNet2DConditionModel
    text_unet: UNet2DConditionModel
    vae: Union[VQModel, AutoencoderKL]
    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]

@@ -72,6 +105,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
            text_encoder=text_encoder,
            image_encoder=image_encoder,
            image_unet=image_unet,
            text_unet=text_unet,
            vae=vae,
            scheduler=scheduler,
        )
@@ -82,12 +116,6 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
            name: module for name, module in text_unet.named_modules() if isinstance(module, Transformer2DModel)
        }

        # text2img by default
        for full_name, module in image_unet.named_modules():
            if isinstance(module, Transformer2DModel):
                parent_name, name = full_name.rsplit(".", 1)
                image_unet.get_submodule(parent_name)[int(name)] = self.text_transformer_blocks[full_name]

    def _encode_prompt(self, prompt, do_classifier_free_guidance):
        r"""
        Encodes the prompt into text encoder hidden states.
@@ -100,7 +128,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
        """

        def _normalize_embeddings(encoder_output):
            embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)  # sum == 19677.4570
            embeds_pooled = encoder_output.text_embeds  # sum == 260.2655
            embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
            return embeds
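Note that `_normalize_embeddings` divides every token embedding by the L2 norm of the pooled embedding, not by each token's own norm; a small shape sketch with illustrative sizes:

```python
import torch

embeds = torch.randn(2, 77, 768)        # projected last_hidden_state: (batch, seq, dim)
embeds_pooled = torch.randn(2, 768)     # pooled text_embeds: (batch, dim)

# unsqueeze(1) -> (2, 1, 768); the norm over dim=-1 with keepdim gives (2, 1, 1),
# which broadcasts one scalar per sample across all 77 token positions.
norm = torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
print((embeds / norm).shape)  # torch.Size([2, 77, 768])
```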
@@ -185,8 +213,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
        text_embeddings = self._encode_prompt(prompt, do_classifier_free_guidance)

        latents = torch.randn(
-            (batch_size, self.image_unet.in_channels, height // 8, width // 8),
-            generator=generator, device=self.device
+            (batch_size, self.image_unet.in_channels, height // 8, width // 8), generator=generator, device=self.device
        )

        self.scheduler.set_timesteps(num_inference_steps)
@@ -198,22 +225,22 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
        if accepts_eta:
            extra_kwargs["eta"] = eta

-        for t in self.progress_bar(self.scheduler.timesteps):
-            t += 1
-            if not do_classifier_free_guidance:
-                latents_input = latents
-            else:
-                latents_input = torch.cat([latents] * 2)
+        with VersatileMixedModel(self.image_unet, self.text_unet, "image", "text") as unet:
+            for t in self.progress_bar(self.scheduler.timesteps):
+                if not do_classifier_free_guidance:
+                    latents_input = latents
+                else:
+                    latents_input = torch.cat([latents] * 2)

-            # predict the noise residual
-            noise_pred = self.image_unet(latents_input, t, encoder_hidden_states=text_embeddings).sample
-            # perform guidance
-            if guidance_scale != 1.0:
-                noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+                # predict the noise residual
+                noise_pred = unet(latents_input, t, encoder_hidden_states=text_embeddings).sample
+                # perform guidance
+                if guidance_scale != 1.0:
+                    noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
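The guidance step in the loop is standard classifier-free guidance: extrapolate from the unconditional prediction toward the text-conditioned one. A self-contained sketch with dummy tensors:

```python
import torch

guidance_scale = 7.5
# Assumed layout: unconditional and text-conditioned predictions stacked on the
# batch dim, matching latents_input = torch.cat([latents] * 2) above.
noise_pred = torch.randn(2, 4, 64, 64)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)  # torch.Size([1, 4, 64, 64])
```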
@@ -228,115 +255,3 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
            return (image,)

        return ImagePipelineOutput(images=image)

class VDMixedModelWrapper(nn.Module):
    def __init__(self, image_unet: UNet2DConditionModel, text_unet: UNet2DConditionModel):
        super().__init__()
        self.image_unet = image_unet
        self.text_unet = text_unet
        self.time_embedding = self.image_unet.time_embedding
        self.time_proj = self.image_unet.time_proj

    def embed_timesteps(self, timesteps, sample):
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])
        t_emb = self.time_proj(timesteps)
        # `time_proj` does not contain any weights and always returns f32 tensors,
        # but `time_embedding` might actually be running in fp16, so we need to cast here.
        # There might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)
        emb = self.time_embedding(t_emb)
        return emb
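A shape sketch of `embed_timesteps` under assumed dimensions (`wrapper` is a hypothetical `VDMixedModelWrapper` instance; sizes are illustrative):

```python
import torch

sample = torch.randn(2, 4, 64, 64)
# A scalar timestep is wrapped into a tensor, broadcast to the batch of 2,
# sinusoidally projected by time_proj, then mapped by the time_embedding MLP.
emb = wrapper.embed_timesteps(50, sample)
print(emb.shape)  # (2, time_embed_dim), e.g. torch.Size([2, 1280])
```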
    def forward(self, sample, timestep, encoder_hidden_states, latents_type="image", condition_type="text", return_dict: bool = True):
        default_overall_up_factor = 2 ** self.image_unet.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            forward_upsample_size = True

        # 1. time
        emb = self.embed_timesteps(timestep, sample)

        # 2. pre-process; the blocks below are all taken from the UNet matching the latent type
        if latents_type == "image":
            unet = self.image_unet
        elif latents_type == "text":
            unet = self.text_unet
        sample = unet.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in unet.down_blocks:
            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)

        # 5. up
        for i, upsample_block in enumerate(unet.up_blocks):
            is_final_block = i == len(unet.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )
        # 6. post-process
        sample = unet.conv_norm_out(sample)
        sample = unet.conv_act(sample)
        sample = unet.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)

def mixed_forward(self, image_module, text_module, hidden_state, timesteps_emb, condition, latents_type="image", condition_type="text"):
    for ilayer, tlayer in zip(image_module, text_module):
        if isinstance(ilayer, Transformer2DModel) and condition_type == "image":
            hidden_state = ilayer(hidden_state, condition)
        elif isinstance(ilayer, Transformer2DModel) and condition_type == "text":
            hidden_state = tlayer(hidden_state, condition)
        elif latents_type == "image":
            hidden_state = ilayer(hidden_state)
        elif latents_type == "text":
            hidden_state = tlayer(hidden_state)
        else:
            raise ValueError(f"latents_type {latents_type} and condition_type {condition_type} not supported")
    return hidden_state

@@ -37,9 +37,9 @@ class VersatileDiffusionPipelineIntegrationTests(unittest.TestCase):
    def test_inference_text2img(self):
        pipe = VersatileDiffusionPipeline.from_pretrained("scripts/vd-diffusers")
        pipe.to(torch_device)
-        #pipe.set_progress_bar_config(disable=None)
+        # pipe.set_progress_bar_config(disable=None)

-        prompt = "a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ"
+        prompt = "A painting of a squirrel eating a burger "
        generator = torch.Generator(device=torch_device).manual_seed(0)
        image = pipe(
            [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy"
@@ -47,6 +47,6 @@ class VersatileDiffusionPipelineIntegrationTests(unittest.TestCase):

        image_slice = image[0, -3:, -3:, -1]

-        assert image.shape == (1, 256, 256, 3)
+        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
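For completeness, a hedged end-to-end sketch mirroring the updated test (the checkpoint path `scripts/vd-diffusers` comes from the test above; the top-level import path is an assumption for this WIP branch):

```python
import torch
from diffusers import VersatileDiffusionPipeline  # assumed export on this branch

pipe = VersatileDiffusionPipeline.from_pretrained("scripts/vd-diffusers").to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
output = pipe(
    ["A painting of a squirrel eating a burger "],
    generator=generator,
    guidance_scale=7.5,
    num_inference_steps=50,
    output_type="numpy",
)
print(output.images.shape)  # expected (1, 512, 512, 3) per the updated assert
```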