
Fix styling issues (#5699)

* up

* up

* up

* Empty-Commit

* fix keyword argument call.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Author: Patrick von Platen
Date: 2023-11-08 13:19:06 +01:00
Committed by: GitHub
Parent: 78be400761
Commit: 17528afcba
3 changed files with 24 additions and 34 deletions
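
The substance of the change is a rename: `prior_text_encoder_hidden_states` and `text_encoder_hidden_states` become the shorter `text_enc_hid_states`, so statements the formatter had wrapped across three lines collapse back to one (hence the net deletion count). As a minimal standalone sketch, not part of the commit and with shapes invented for illustration, the duplication pattern those renamed lines implement looks like this:

import torch

# Each prompt's embeddings are duplicated once per requested image before
# the prior/decoder forward passes. Shapes are illustrative only.
num_images_per_prompt = 2
prompt_embeds = torch.randn(1, 768)            # pooled text embedding (batch, dim)
text_enc_hid_states = torch.randn(1, 77, 768)  # per-token hidden states (batch, seq_len, dim)

prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)

print(prompt_embeds.shape)        # torch.Size([2, 768])
print(text_enc_hid_states.shape)  # torch.Size([2, 77, 768])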

File (inferred from hunk context): src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py

@@ -206,17 +206,15 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             prior_text_encoder_output = self.prior_text_encoder(text_input_ids.to(device))

             prompt_embeds = prior_text_encoder_output.text_embeds
-            prior_text_encoder_hidden_states = prior_text_encoder_output.last_hidden_state
+            text_enc_hid_states = prior_text_encoder_output.last_hidden_state

         else:
             batch_size = text_model_output[0].shape[0]
-            prompt_embeds, prior_text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+            prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1]
             text_mask = text_attention_mask

         prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-        prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.repeat_interleave(
-            num_images_per_prompt, dim=0
-        )
+        text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)
         text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

         if do_classifier_free_guidance:
@@ -235,9 +233,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             )

             negative_prompt_embeds = negative_prompt_embeds_prior_text_encoder_output.text_embeds
-            uncond_prior_text_encoder_hidden_states = (
-                negative_prompt_embeds_prior_text_encoder_output.last_hidden_state
-            )
+            uncond_text_enc_hid_states = negative_prompt_embeds_prior_text_encoder_output.last_hidden_state

             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
@@ -245,11 +241,9 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)

-            seq_len = uncond_prior_text_encoder_hidden_states.shape[1]
-            uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.repeat(
-                1, num_images_per_prompt, 1
-            )
-            uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.view(
+            seq_len = uncond_text_enc_hid_states.shape[1]
+            uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1)
+            uncond_text_enc_hid_states = uncond_text_enc_hid_states.view(
                 batch_size * num_images_per_prompt, seq_len, -1
             )

             uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
@@ -260,13 +254,11 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-            prior_text_encoder_hidden_states = torch.cat(
-                [uncond_prior_text_encoder_hidden_states, prior_text_encoder_hidden_states]
-            )
+            text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states])

             text_mask = torch.cat([uncond_text_mask, text_mask])

-        return prompt_embeds, prior_text_encoder_hidden_states, text_mask
+        return prompt_embeds, text_enc_hid_states, text_mask

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
     def _encode_prompt(
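
The unconditional branch above duplicates embeddings with `repeat(...)` followed by `view(...)` rather than `repeat_interleave`, which the surrounding code calls the "mps friendly method". A standalone sketch, with assumed shapes rather than values from the diff, showing the two give identical results along the batch dimension:

import torch

# repeat(1, n, 1) tiles along the sequence axis; view(...) then folds the tiles
# back into the batch axis, matching repeat_interleave(n, dim=0). The detour
# presumably avoided a missing or slow repeat_interleave kernel on mps.
num_images_per_prompt = 2
uncond_text_enc_hid_states = torch.randn(3, 77, 768)  # (batch, seq_len, dim)

batch_size, seq_len = uncond_text_enc_hid_states.shape[:2]
mps_friendly = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1).view(
    batch_size * num_images_per_prompt, seq_len, -1
)
direct = uncond_text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)
assert torch.equal(mps_friendly, direct)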

File (inferred from hunk context): src/diffusers/pipelines/unclip/pipeline_unclip.py

@@ -156,15 +156,15 @@ class UnCLIPPipeline(DiffusionPipeline):
             text_encoder_output = self.text_encoder(text_input_ids.to(device))

             prompt_embeds = text_encoder_output.text_embeds
-            text_encoder_hidden_states = text_encoder_output.last_hidden_state
+            text_enc_hid_states = text_encoder_output.last_hidden_state

         else:
             batch_size = text_model_output[0].shape[0]
-            prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+            prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1]
             text_mask = text_attention_mask

         prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+        text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)
         text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

         if do_classifier_free_guidance:
@@ -181,7 +181,7 @@ class UnCLIPPipeline(DiffusionPipeline):
             negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))

             negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
-            uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+            uncond_text_enc_hid_states = negative_prompt_embeds_text_encoder_output.last_hidden_state

             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
@@ -189,9 +189,9 @@ class UnCLIPPipeline(DiffusionPipeline):
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)

-            seq_len = uncond_text_encoder_hidden_states.shape[1]
-            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
-            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+            seq_len = uncond_text_enc_hid_states.shape[1]
+            uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1)
+            uncond_text_enc_hid_states = uncond_text_enc_hid_states.view(
                 batch_size * num_images_per_prompt, seq_len, -1
             )

             uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
@@ -202,11 +202,11 @@ class UnCLIPPipeline(DiffusionPipeline):
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+            text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states])

             text_mask = torch.cat([uncond_text_mask, text_mask])

-        return prompt_embeds, text_encoder_hidden_states, text_mask
+        return prompt_embeds, text_enc_hid_states, text_mask

     @torch.no_grad()
     def __call__(
@@ -293,7 +293,7 @@ class UnCLIPPipeline(DiffusionPipeline):
         do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0

-        prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
+        prompt_embeds, text_enc_hid_states, text_mask = self._encode_prompt(
             prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
         )
@@ -321,7 +321,7 @@ class UnCLIPPipeline(DiffusionPipeline):
                 latent_model_input,
                 timestep=t,
                 proj_embedding=prompt_embeds,
-                encoder_hidden_states=text_encoder_hidden_states,
+                encoder_hidden_states=text_enc_hid_states,
                 attention_mask=text_mask,
             ).predicted_image_embedding
@@ -352,10 +352,10 @@ class UnCLIPPipeline(DiffusionPipeline):
         # decoder

-        text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
+        text_enc_hid_states, additive_clip_time_embeddings = self.text_proj(
             image_embeddings=image_embeddings,
             prompt_embeds=prompt_embeds,
-            text_encoder_hidden_states=text_encoder_hidden_states,
+            text_encoder_hidden_states=text_enc_hid_states,
             do_classifier_free_guidance=do_classifier_free_guidance,
         )
@@ -377,7 +377,7 @@ class UnCLIPPipeline(DiffusionPipeline):
         decoder_latents = self.prepare_latents(
             (batch_size, num_channels_latents, height, width),
-            text_encoder_hidden_states.dtype,
+            text_enc_hid_states.dtype,
             device,
             generator,
             decoder_latents,
@@ -391,7 +391,7 @@ class UnCLIPPipeline(DiffusionPipeline):
             noise_pred = self.decoder(
                 sample=latent_model_input,
                 timestep=t,
-                encoder_hidden_states=text_encoder_hidden_states,
+                encoder_hidden_states=text_enc_hid_states,
                 class_labels=additive_clip_time_embeddings,
                 attention_mask=decoder_text_mask,
             ).sample
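
One detail in the `self.text_proj(...)` hunk above explains the "fix keyword argument call." line in the commit message: only the caller-side variable is renamed to `text_enc_hid_states`, while the keyword stays `text_encoder_hidden_states=`, since that name belongs to the callee's signature. A hypothetical stub, not the real text-projection model, mirroring just that call shape:

# Stub only: the parameter name is fixed by the callee, so renaming the
# caller's local variable must not touch the keyword.
def text_proj(*, image_embeddings, prompt_embeds, text_encoder_hidden_states, do_classifier_free_guidance):
    return text_encoder_hidden_states

text_enc_hid_states = "hidden states"  # renamed local on the caller side
result = text_proj(
    image_embeddings=None,
    prompt_embeds=None,
    text_encoder_hidden_states=text_enc_hid_states,  # keyword name unchanged
    do_classifier_free_guidance=False,
)
assert result == "hidden states"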

File (inferred from hunk context): src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

@@ -1494,7 +1494,6 @@ class ResnetBlockFlat(nn.Module):
         return output_tensor

-
 # Copied from diffusers.models.unet_2d_blocks.DownBlock2D with DownBlock2D->DownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
 class DownBlockFlat(nn.Module):
     def __init__(
         self,
@@ -1583,7 +1582,6 @@ class DownBlockFlat(nn.Module):
         return hidden_states, output_states

-
 # Copied from diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D with CrossAttnDownBlock2D->CrossAttnDownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
 class CrossAttnDownBlockFlat(nn.Module):
     def __init__(
         self,
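
The two hunks above change no code; judging from the one-line reduction in each hunk header, each just drops an extra blank line before a `# Copied from ...` marker, the kind of normalization the repo's `make style` formatting target applies. The markers themselves instruct diffusers' copy-consistency tooling to keep these Flat blocks in sync with the unet_2d_blocks originals they name, so formatting fixes in the originals propagate here as well.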