From 17528afcba583fa0e49504208af33b1a62ff1294 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 8 Nov 2023 13:19:06 +0100 Subject: [PATCH] Fix styling issues (#5699) * up * up * up * Empty-Commit * fix keyword argument call. --------- Co-authored-by: Sayak Paul --- .../pipeline_stable_unclip.py | 26 ++++++---------- .../pipelines/unclip/pipeline_unclip.py | 30 +++++++++---------- .../versatile_diffusion/modeling_text_unet.py | 2 -- 3 files changed, 24 insertions(+), 34 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index c81dd85f0e..eb4542888c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -206,17 +206,15 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL prior_text_encoder_output = self.prior_text_encoder(text_input_ids.to(device)) prompt_embeds = prior_text_encoder_output.text_embeds - prior_text_encoder_hidden_states = prior_text_encoder_output.last_hidden_state + text_enc_hid_states = prior_text_encoder_output.last_hidden_state else: batch_size = text_model_output[0].shape[0] - prompt_embeds, prior_text_encoder_hidden_states = text_model_output[0], text_model_output[1] + prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1] text_mask = text_attention_mask prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) - prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) + text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: @@ -235,9 +233,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL ) negative_prompt_embeds = negative_prompt_embeds_prior_text_encoder_output.text_embeds - uncond_prior_text_encoder_hidden_states = ( - negative_prompt_embeds_prior_text_encoder_output.last_hidden_state - ) + uncond_text_enc_hid_states = negative_prompt_embeds_prior_text_encoder_output.last_hidden_state # duplicate unconditional embeddings for each generation per prompt, using mps friendly method @@ -245,11 +241,9 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) - seq_len = uncond_prior_text_encoder_hidden_states.shape[1] - uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.repeat( - 1, num_images_per_prompt, 1 - ) - uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.view( + seq_len = uncond_text_enc_hid_states.shape[1] + uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1) + uncond_text_enc_hid_states = uncond_text_enc_hid_states.view( batch_size * num_images_per_prompt, seq_len, -1 ) uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) @@ -260,13 +254,11 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - prior_text_encoder_hidden_states = torch.cat( - [uncond_prior_text_encoder_hidden_states, prior_text_encoder_hidden_states] - ) + text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states]) text_mask = torch.cat([uncond_text_mask, text_mask]) - return prompt_embeds, prior_text_encoder_hidden_states, text_mask + return prompt_embeds, text_enc_hid_states, text_mask # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py index c4a25c865d..7bebed73c1 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py @@ -156,15 +156,15 @@ class UnCLIPPipeline(DiffusionPipeline): text_encoder_output = self.text_encoder(text_input_ids.to(device)) prompt_embeds = text_encoder_output.text_embeds - text_encoder_hidden_states = text_encoder_output.last_hidden_state + text_enc_hid_states = text_encoder_output.last_hidden_state else: batch_size = text_model_output[0].shape[0] - prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1] + prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1] text_mask = text_attention_mask prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) - text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: @@ -181,7 +181,7 @@ class UnCLIPPipeline(DiffusionPipeline): negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds - uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + uncond_text_enc_hid_states = negative_prompt_embeds_text_encoder_output.last_hidden_state # duplicate unconditional embeddings for each generation per prompt, using mps friendly method @@ -189,9 +189,9 @@ class UnCLIPPipeline(DiffusionPipeline): negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) - seq_len = uncond_text_encoder_hidden_states.shape[1] - uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) - uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + seq_len = uncond_text_enc_hid_states.shape[1] + uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1) + uncond_text_enc_hid_states = uncond_text_enc_hid_states.view( batch_size * num_images_per_prompt, seq_len, -1 ) uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) @@ -202,11 +202,11 @@ class UnCLIPPipeline(DiffusionPipeline): # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states]) text_mask = torch.cat([uncond_text_mask, text_mask]) - return prompt_embeds, text_encoder_hidden_states, text_mask + return prompt_embeds, text_enc_hid_states, text_mask @torch.no_grad() def __call__( @@ -293,7 +293,7 @@ class UnCLIPPipeline(DiffusionPipeline): do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 - prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt_embeds, text_enc_hid_states, text_mask = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask ) @@ -321,7 +321,7 @@ class UnCLIPPipeline(DiffusionPipeline): latent_model_input, timestep=t, proj_embedding=prompt_embeds, - encoder_hidden_states=text_encoder_hidden_states, + encoder_hidden_states=text_enc_hid_states, attention_mask=text_mask, ).predicted_image_embedding @@ -352,10 +352,10 @@ class UnCLIPPipeline(DiffusionPipeline): # decoder - text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( + text_enc_hid_states, additive_clip_time_embeddings = self.text_proj( image_embeddings=image_embeddings, prompt_embeds=prompt_embeds, - text_encoder_hidden_states=text_encoder_hidden_states, + text_encoder_hidden_states=text_enc_hid_states, do_classifier_free_guidance=do_classifier_free_guidance, ) @@ -377,7 +377,7 @@ class UnCLIPPipeline(DiffusionPipeline): decoder_latents = self.prepare_latents( (batch_size, num_channels_latents, height, width), - text_encoder_hidden_states.dtype, + text_enc_hid_states.dtype, device, generator, decoder_latents, @@ -391,7 +391,7 @@ class UnCLIPPipeline(DiffusionPipeline): noise_pred = self.decoder( sample=latent_model_input, timestep=t, - encoder_hidden_states=text_encoder_hidden_states, + encoder_hidden_states=text_enc_hid_states, class_labels=additive_clip_time_embeddings, attention_mask=decoder_text_mask, ).sample diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 32147ffa45..60ea3d814b 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1494,7 +1494,6 @@ class ResnetBlockFlat(nn.Module): return output_tensor -# Copied from diffusers.models.unet_2d_blocks.DownBlock2D with DownBlock2D->DownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim class DownBlockFlat(nn.Module): def __init__( self, @@ -1583,7 +1582,6 @@ class DownBlockFlat(nn.Module): return hidden_states, output_states -# Copied from diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D with CrossAttnDownBlock2D->CrossAttnDownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim class CrossAttnDownBlockFlat(nn.Module): def __init__( self,