diff --git a/scripts/convert_versatile_diffusion_to_diffusers.py b/scripts/convert_versatile_diffusion_to_diffusers.py
index 27fe7c1ffa..77191e6cb6 100644
--- a/scripts/convert_versatile_diffusion_to_diffusers.py
+++ b/scripts/convert_versatile_diffusion_to_diffusers.py
@@ -313,7 +313,7 @@ def create_vae_diffusers_config(vae_params):
     return config


-def create_diffusers_schedular(original_config):
+def create_diffusers_scheduler(original_config):
     schedular = DDIMScheduler(
         num_train_timesteps=original_config.model.params.timesteps,
         beta_start=original_config.model.params.linear_start,
@@ -323,16 +323,6 @@ def create_diffusers_schedular(original_config):
     return schedular


-def create_ldm_bert_config(original_config):
-    bert_params = original_config.model.parms.cond_stage_config.params
-    config = LDMBertConfig(
-        d_model=bert_params.n_embed,
-        encoder_layers=bert_params.n_layer,
-        encoder_ffn_dim=bert_params.n_embed * 4,
-    )
-    return config
-
-
 def convert_vd_unet_checkpoint(checkpoint, config, unet_key, extract_ema=False):
     """
     Takes a state dict and a config, and returns a converted checkpoint.
@@ -503,7 +493,7 @@ def convert_vd_unet_checkpoint(checkpoint, config, unet_key, extract_ema=False):
     return new_checkpoint


-def convert_ldm_vae_checkpoint(checkpoint, config):
+def convert_vd_vae_checkpoint(checkpoint, config):
     # extract state dict for VAE
     vae_state_dict = {}
     keys = list(checkpoint.keys())
@@ -608,72 +598,6 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
     return new_checkpoint


-def convert_ldm_bert_checkpoint(checkpoint, config):
-    def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
-        hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
-        hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
-        hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
-
-        hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
-        hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
-
-    def _copy_linear(hf_linear, pt_linear):
-        hf_linear.weight = pt_linear.weight
-        hf_linear.bias = pt_linear.bias
-
-    def _copy_layer(hf_layer, pt_layer):
-        # copy layer norms
-        _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
-        _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
-
-        # copy attn
-        _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
-
-        # copy MLP
-        pt_mlp = pt_layer[1][1]
-        _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
-        _copy_linear(hf_layer.fc2, pt_mlp.net[2])
-
-    def _copy_layers(hf_layers, pt_layers):
-        for i, hf_layer in enumerate(hf_layers):
-            if i != 0:
-                i += i
-            pt_layer = pt_layers[i : i + 2]
-            _copy_layer(hf_layer, pt_layer)
-
-    hf_model = LDMBertModel(config).eval()
-
-    # copy embeds
-    hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
-    hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
-
-    # copy layer norm
-    _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
-
-    # copy hidden layers
-    _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
-
-    _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
-
-    return hf_model
-
-
-def convert_ldm_clip_checkpoint(checkpoint):
-    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
-
-    keys = list(checkpoint.keys())
-
-    text_model_dict = {}
-
-    for key in keys:
-        if key.startswith("cond_stage_model.transformer"):
-            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
-
-    text_model.load_state_dict(text_model_dict)
-
-    return text_model
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -758,18 +682,18 @@ if __name__ == "__main__":
     image_unet = UNet2DConditionModel(**image_unet_config)
     image_unet.load_state_dict(converted_image_unet_checkpoint)

-    # text_unet_config = create_unet_diffusers_config(TEXT_UNET_CONFIG)
-    # converted_text_unet_checkpoint = convert_vd_unet_checkpoint(
-    #     checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema
-    # )
-    # text_unet = UNet2DConditionModel(**text_unet_config)
-    # text_unet.load_state_dict(converted_text_unet_checkpoint)
+    text_unet_config = create_unet_diffusers_config(TEXT_UNET_CONFIG)
+    converted_text_unet_checkpoint = convert_vd_unet_checkpoint(
+        checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema
+    )
+    text_unet = UNet2DConditionModel(**text_unet_config)
+    text_unet.load_state_dict(converted_text_unet_checkpoint, strict=False)

     # Convert the VAE model.
     if args.vae_checkpoint_path is not None:
         vae_config = create_vae_diffusers_config(AUTOENCODER_CONFIG)
         checkpoint = torch.load(args.vae_checkpoint_path)
-        converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
+        converted_vae_checkpoint = convert_vd_vae_checkpoint(checkpoint, vae_config)

         vae = AutoencoderKL(**vae_config)
         vae.load_state_dict(converted_vae_checkpoint)
@@ -786,7 +710,7 @@ if __name__ == "__main__":
         text_encoder=text_encoder,
         image_encoder=image_encoder,
         image_unet=image_unet,
-        # text_unet=text_unet,
+        text_unet=text_unet,
         vae=vae,
     )
     pipe.save_pretrained(args.dump_path)
diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py
index 3e56e4b57e..09835e4a0d 100644
--- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py
+++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py
@@ -21,7 +21,10 @@ import torch.utils.checkpoint
+from torch import nn

 from transformers import CLIPProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel

-from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
+from ...models import AutoencoderKL, UNet2DConditionModel, VQModel
+from ...models.unet_2d_condition import UNet2DConditionOutput
+from ...models.attention import Transformer2DModel
 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -59,6 +61,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
         text_encoder: CLIPTextModel,
         image_encoder: CLIPVisionModel,
         image_unet: UNet2DConditionModel,
+        text_unet: UNet2DConditionModel,
         vae: Union[VQModel, AutoencoderKL],
         scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
     ):
@@ -72,6 +75,18 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
             vae=vae,
             scheduler=scheduler,
         )
+        self.image_transformer_blocks = {
+            name: module for name, module in image_unet.named_modules() if isinstance(module, Transformer2DModel)
+        }
+        self.text_transformer_blocks = {
+            name: module for name, module in text_unet.named_modules() if isinstance(module, Transformer2DModel)
+        }
+
+        # text2img by default: swap the image UNet's cross-attention transformers for the text UNet's
+        for full_name, module in list(image_unet.named_modules()):
+            if isinstance(module, Transformer2DModel):
+                parent_name, name = full_name.rsplit(".", 1)
+                setattr(image_unet.get_submodule(parent_name), name, self.text_transformer_blocks[full_name])

     def _encode_prompt(self, prompt, do_classifier_free_guidance):
         r"""
@@ -85,8 +100,8 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
         """

         def _normalize_embeddings(encoder_output):
-            embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)
-            embeds_pooled = encoder_output.text_embeds
+            embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)  # sum == 19677.4570
+            embeds_pooled = encoder_output.text_embeds  # sum == 260.2655
             embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
             return embeds

@@ -171,9 +186,8 @@ class VersatileDiffusionPipeline(DiffusionPipeline):

         latents = torch.randn(
             (batch_size, self.image_unet.in_channels, height // 8, width // 8),
-            generator=generator,
+            generator=generator, device=self.device
         )
-        latents = latents.to(self.device)

         self.scheduler.set_timesteps(num_inference_steps)

@@ -185,6 +199,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
             extra_kwargs["eta"] = eta

         for t in self.progress_bar(self.scheduler.timesteps):
+            t += 1
             if not do_classifier_free_guidance:
                 latents_input = latents
             else:
@@ -213,3 +228,115 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
             return (image,)

         return ImagePipelineOutput(images=image)
+
+
+class VDMixedModelWrapper(nn.Module):
+    def __init__(self, image_unet: UNet2DConditionModel, text_unet: UNet2DConditionModel):
+        super().__init__()
+        self.image_unet = image_unet
+        self.text_unet = text_unet
+        self.time_embedding = self.image_unet.time_embedding
+        self.time_proj = self.image_unet.time_proj
+
+    def embed_timesteps(self, timesteps, sample):
+        if not torch.is_tensor(timesteps):
+            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps.expand(sample.shape[0])
+        t_emb = self.time_proj(timesteps)
+        # `time_proj` does not contain any weights and will always return f32 tensors,
+        # but time_embedding might actually be running in fp16, so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=sample.dtype)
+        emb = self.time_embedding(t_emb)
+        return emb
+
+    def forward(self, sample, timestep, encoder_hidden_states, latents_type="image", condition_type="text", return_dict: bool = True):
+        # the data stream (conv / resnet blocks) follows `latents_type`; `mixed_forward` below is meant to
+        # route the cross-attention transformers through the `condition_type` UNet
+        unet = self.image_unet if latents_type == "image" else self.text_unet
+
+        default_overall_up_factor = 2 ** self.image_unet.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        forward_upsample_size = False
+        upsample_size = None
+
+        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+            forward_upsample_size = True
+
+        # 1. time
+        emb = self.embed_timesteps(timestep, sample)
+
+        # 2. pre-process
+        if latents_type == "image":
+            sample = self.image_unet.conv_in(sample)
+        elif latents_type == "text":
+            sample = self.text_unet.conv_in(sample)
+
+        # 3. down
+        down_block_res_samples = (sample,)
+        for downsample_block in unet.down_blocks:
+            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
+                sample, res_samples = downsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                )
+            else:
+                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+            down_block_res_samples += res_samples
+
+        # 4. mid
+        sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
+
+        # 5. up
+        for i, upsample_block in enumerate(unet.up_blocks):
+            is_final_block = i == len(unet.up_blocks) - 1
+
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+            # if we have not reached the final block and need to forward the
+            # upsample size, we do it here
+            if not is_final_block and forward_upsample_size:
+                upsample_size = down_block_res_samples[-1].shape[2:]
+
+            if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    encoder_hidden_states=encoder_hidden_states,
+                    upsample_size=upsample_size,
+                )
+            else:
+                sample = upsample_block(
+                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+                )
+        # 6. post-process
+        sample = unet.conv_norm_out(sample)
+        sample = unet.conv_act(sample)
+        sample = unet.conv_out(sample)
+
+        if not return_dict:
+            return (sample,)
+
+        return UNet2DConditionOutput(sample=sample)
+
+    def mixed_forward(self, image_module, text_module, hidden_state, timesteps_emb, condition, latents_type="image", condition_type="text"):
+        for ilayer, tlayer in zip(image_module, text_module):
+            if isinstance(ilayer, Transformer2DModel) and condition_type == "image":
+                hidden_state = ilayer(hidden_state, condition)
+            elif isinstance(ilayer, Transformer2DModel) and condition_type == "text":
+                hidden_state = tlayer(hidden_state, condition)
+            elif latents_type == "image":
+                hidden_state = ilayer(hidden_state)
+            elif latents_type == "text":
+                hidden_state = tlayer(hidden_state)
+            else:
+                raise ValueError(f"latents_type {latents_type} and condition_type {condition_type} not supported")
+        return hidden_state
diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion.py
index e0f7618034..be6c826af1 100644
--- a/tests/pipelines/versatile_diffusion/test_versatile_diffusion.py
+++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion.py
@@ -37,10 +37,10 @@ class VersatileDiffusionPipelineIntegrationTests(unittest.TestCase):
     def test_inference_text2img(self):
         pipe = VersatileDiffusionPipeline.from_pretrained("scripts/vd-diffusers")
         pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        # pipe.set_progress_bar_config(disable=None)

-        prompt = "A painting of a squirrel eating a burger"
-        generator = torch.manual_seed(0)
+        prompt = "a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ"
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = pipe(
             [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy"
         ).images
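# -------------------------------------------------------------------------------------------------
# Reviewer note (not part of the patch): a minimal, self-contained sketch of the Transformer2DModel
# block swap that `VersatileDiffusionPipeline.__init__` performs above, using `get_submodule` plus
# `setattr` so the child module is actually rebound. The helper name `swap_transformer_blocks` and
# the tiny dummy UNet config are illustrative assumptions; the `diffusers.models.attention` import
# path mirrors the relative import used in the patched pipeline file and may differ in other
# diffusers versions.
from diffusers import UNet2DConditionModel
from diffusers.models.attention import Transformer2DModel


def swap_transformer_blocks(dst_unet: UNet2DConditionModel, src_unet: UNet2DConditionModel) -> None:
    """Replace every Transformer2DModel in `dst_unet` with the identically named block from `src_unet`."""
    # materialize the target list first so the module tree is not mutated while it is being iterated
    targets = [(name, mod) for name, mod in dst_unet.named_modules() if isinstance(mod, Transformer2DModel)]
    src_blocks = {name: mod for name, mod in src_unet.named_modules() if isinstance(mod, Transformer2DModel)}
    for full_name, _ in targets:
        parent_name, child_name = full_name.rsplit(".", 1)
        # setattr on the parent rebinds the submodule; plain item assignment on an nn.Module would fail
        setattr(dst_unet.get_submodule(parent_name), child_name, src_blocks[full_name])


if __name__ == "__main__":
    # tiny toy UNets (config borrowed from the diffusers test dummies) just to exercise the swap
    config = dict(
        block_out_channels=(32, 64),
        layers_per_block=2,
        sample_size=32,
        in_channels=4,
        out_channels=4,
        down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
        up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
        cross_attention_dim=32,
    )
    image_unet = UNet2DConditionModel(**config)
    text_unet = UNet2DConditionModel(**config)
    swap_transformer_blocks(image_unet, text_unet)

    # after the swap, the image UNet's attention blocks are literally the text UNet's modules
    name = next(n for n, m in text_unet.named_modules() if isinstance(m, Transformer2DModel))
    assert image_unet.get_submodule(name) is text_unet.get_submodule(name)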