
mixed inference for text2img

anton-l
2022-11-16 17:23:29 +01:00
parent 53f080f17a
commit b5778e0ff3
3 changed files with 76 additions and 141 deletions


@@ -648,6 +648,7 @@ if __name__ == "__main__":
beta_start=beta_start,
num_train_timesteps=num_train_timesteps,
skip_prk_steps=True,
steps_offset=1,
)
elif args.scheduler_type == "lms":
scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
@@ -668,12 +669,14 @@ if __name__ == "__main__":
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
steps_offset=1,
)
else:
raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
# Convert the UNet2DConditionModel model.
# Convert the UNet2DConditionModel models.
if args.unet_checkpoint_path is not None:
# image UNet
image_unet_config = create_unet_diffusers_config(IMAGE_UNET_CONFIG)
checkpoint = torch.load(args.unet_checkpoint_path)
converted_image_unet_checkpoint = convert_vd_unet_checkpoint(
@@ -682,11 +685,28 @@ if __name__ == "__main__":
image_unet = UNet2DConditionModel(**image_unet_config)
image_unet.load_state_dict(converted_image_unet_checkpoint)
# text UNet
text_unet_config = create_unet_diffusers_config(TEXT_UNET_CONFIG)
converted_text_unet_checkpoint = convert_vd_unet_checkpoint(
checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema
)
text_unet = UNet2DConditionModel(**text_unet_config)
# TEMP hack to skip converting the 1x1 blocks for the text unet
del converted_text_unet_checkpoint["conv_in.weight"]
del converted_text_unet_checkpoint["conv_in.bias"]
del converted_text_unet_checkpoint["conv_out.weight"]
for block in ["down_blocks", "mid_block", "up_blocks"]:
    for i in range(4):
        for j in range(3):
            for module in ["time_emb_proj", "conv1", "norm1", "conv2", "norm2", "conv_shortcut"]:
                for type in ["weight", "bias"]:
                    if block == "mid_block":
                        key = f"{block}.resnets.{j}.{module}.{type}"
                    else:
                        key = f"{block}.{i}.resnets.{j}.{module}.{type}"
                    if key in converted_text_unet_checkpoint:
                        del converted_text_unet_checkpoint[key]
# END TEMP hack
text_unet.load_state_dict(converted_text_unet_checkpoint, strict=False)
# Convert the VAE model.
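
Aside on the `strict=False` load in the text-UNet hunk above: a minimal toy sketch (plain `torch.nn` modules, illustrative only and not part of this commit) of why deleting the not-yet-converted keys is safe — missing entries are merely reported, and the corresponding modules keep their existing initialization:

    import torch.nn as nn

    # toy model standing in for the text UNet
    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))

    # keep only the keys of the first Linear, mimicking the deleted 1x1-block keys
    partial = {k: v for k, v in model.state_dict().items() if k.startswith("0.")}

    missing, unexpected = model.load_state_dict(partial, strict=False)
    print(missing)     # ['1.weight', '1.bias'] -- left at their current (random) values
    print(unexpected)  # []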


@@ -13,21 +13,53 @@
# limitations under the License.
import inspect
from functools import reduce
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers import CLIPProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel
from ...models import AutoencoderKL, UNet2DConditionModel, VQModel
from ...models.unet_2d_condition import UNet2DConditionOutput
from ...models.attention import Transformer2DModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
class VersatileMixedModel:
    """
    A context manager that swaps the transformer modules between the image and text UNets during inference,
    depending on the latent type and condition type.
    """

    def __init__(self, image_unet, text_unet, latent_type, condition_type):
        self.image_unet = image_unet
        self.text_unet = text_unet
        self.latent_type = latent_type
        self.condition_type = condition_type

    def swap_transformer_modules(self):
        for name, module in self.image_unet.named_modules():
            if isinstance(module, Transformer2DModel):
                parent_name, index = name.rsplit(".", 1)
                index = int(index)
                self.image_unet.get_submodule(parent_name)[index], self.text_unet.get_submodule(parent_name)[index] = (
                    self.text_unet.get_submodule(parent_name)[index],
                    self.image_unet.get_submodule(parent_name)[index],
                )

    def __enter__(self):
        if self.latent_type != self.condition_type:
            self.swap_transformer_modules()
        return self.image_unet if self.latent_type == "image" else self.text_unet

    def __exit__(self, *exc):
        # swap the modules back
        if self.latent_type != self.condition_type:
            self.swap_transformer_modules()
class VersatileDiffusionPipeline(DiffusionPipeline):
r"""
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
@@ -51,6 +83,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
text_encoder: CLIPTextModel
image_encoder: CLIPVisionModel
image_unet: UNet2DConditionModel
text_unet: UNet2DConditionModel
vae: Union[VQModel, AutoencoderKL]
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
@@ -72,6 +105,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
text_encoder=text_encoder,
image_encoder=image_encoder,
image_unet=image_unet,
text_unet=text_unet,
vae=vae,
scheduler=scheduler,
)
@@ -82,12 +116,6 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
name: module for name, module in text_unet.named_modules() if isinstance(module, Transformer2DModel)
}
# text2img by default
for full_name, module in image_unet.named_modules():
if isinstance(module, Transformer2DModel):
parent_name, name = full_name.rsplit('.', 1)
image_unet.get_submodule(parent_name)[name] = self.text_transformer_blocks[name]
def _encode_prompt(self, prompt, do_classifier_free_guidance):
r"""
Encodes the prompt into text encoder hidden states.
@@ -100,7 +128,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
"""
def _normalize_embeddings(encoder_output):
embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)  # sum == 19677.4570
embeds_pooled = encoder_output.text_embeds # sum == 260.2655
embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
return embeds
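
A small numeric sketch (random tensors with assumed CLIP-like shapes, not the pipeline's real embeddings) of what `_normalize_embeddings` computes — every token embedding of a prompt is divided by the L2 norm of that prompt's pooled embedding, so all tokens share one scale factor:

    import torch

    embeds = torch.randn(2, 77, 768)    # (batch, seq_len, dim): projected token states
    embeds_pooled = torch.randn(2, 768)  # (batch, dim): pooled text embedding

    scale = torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)  # shape (2, 1, 1)
    normalized = embeds / scale          # broadcasts one scalar per prompt
    print(normalized.shape)              # torch.Size([2, 77, 768])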
@@ -185,8 +213,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
text_embeddings = self._encode_prompt(prompt, do_classifier_free_guidance)
latents = torch.randn(
(batch_size, self.image_unet.in_channels, height // 8, width // 8),
generator=generator, device=self.device
(batch_size, self.image_unet.in_channels, height // 8, width // 8), generator=generator, device=self.device
)
self.scheduler.set_timesteps(num_inference_steps)
@@ -198,22 +225,22 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
if accepts_eta:
extra_kwargs["eta"] = eta
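
The `accepts_eta` flag is presumably derived from signature inspection, as in other diffusers pipelines; a toy sketch (hypothetical `DummyScheduler`, illustrative only and not part of this commit) of the idea — `eta` is only forwarded when the scheduler's `step()` accepts it:

    import inspect

    class DummyScheduler:
        def step(self, noise_pred, t, latents, eta=0.0):
            return latents

    scheduler = DummyScheduler()
    accepts_eta = "eta" in inspect.signature(scheduler.step).parameters
    extra_kwargs = {"eta": 0.0} if accepts_eta else {}
    print(accepts_eta, extra_kwargs)  # True {'eta': 0.0}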
for t in self.progress_bar(self.scheduler.timesteps):
t += 1
if not do_classifier_free_guidance:
latents_input = latents
else:
latents_input = torch.cat([latents] * 2)
with VersatileMixedModel(self.image_unet, self.text_unet, "image", "text") as unet:
for t in self.progress_bar(self.scheduler.timesteps):
if not do_classifier_free_guidance:
latents_input = latents
else:
latents_input = torch.cat([latents] * 2)
# predict the noise residual
noise_pred = self.image_unet(latents_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if guidance_scale != 1.0:
noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
# predict the noise residual
noise_pred = unet(latents_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if guidance_scale != 1.0:
noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
@@ -228,115 +255,3 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
return (image,)
return ImagePipelineOutput(images=image)
class VDMixedModelWrapper(nn.Module):
    def __init__(self, image_unet: UNet2DConditionModel, text_unet: UNet2DConditionModel):
        super().__init__()
        self.image_unet = image_unet
        self.text_unet = text_unet
        self.time_embedding = self.image_unet.time_embedding
        self.time_proj = self.image_unet.time_proj

    def embed_timesteps(self, timesteps, sample):
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])
        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)
        return emb

    def forward(self, sample, timestep, encoder_hidden_states, latents_type="image", condition_type="text", return_dict: bool = True):
        default_overall_up_factor = 2 ** self.image_unet.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None
        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            forward_upsample_size = True

        # 1. time
        emb = self.embed_timesteps(timestep, sample)

        # 2. pre-process
        if latents_type == "image":
            sample = self.image_unet.conv_in(sample)
        elif latents_type == "text":
            sample = self.text_unet.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)
        return UNet2DConditionOutput(sample=sample)

    def mixed_forward(self, image_module, text_module, hidden_state, timesteps_emb, condition, latents_type="image", condition_type="text"):
        for ilayer, tlayer in zip(image_module, text_module):
            if isinstance(ilayer, SpatialTransformer) and condition_type == 'image':
                hidden_state = ilayer(hidden_state, condition)
            elif isinstance(ilayer, SpatialTransformer) and condition_type == 'text':
                hidden_state = tlayer(hidden_state, condition)
            elif latents_type == 'image':
                hidden_state = ilayer(hidden_state)
            elif latents_type == 'text':
                hidden_state = tlayer(hidden_state)
            else:
                raise ValueError(f"latents_type {latents_type} and condition_type {condition_type} not supported")
        return hidden_state


@@ -37,9 +37,9 @@ class VersatileDiffusionPipelineIntegrationTests(unittest.TestCase):
def test_inference_text2img(self):
pipe = VersatileDiffusionPipeline.from_pretrained("scripts/vd-diffusers")
pipe.to(torch_device)
#pipe.set_progress_bar_config(disable=None)
# pipe.set_progress_bar_config(disable=None)
prompt = "a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ"
prompt = "A painting of a squirrel eating a burger "
generator = torch.Generator(device=torch_device).manual_seed(0)
image = pipe(
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy"
@@ -47,6 +47,6 @@ class VersatileDiffusionPipelineIntegrationTests(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 256, 256, 3)
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2