Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Fix styling issues (#5699)

* up
* up
* up
* Empty-Commit
* fix keyword argument call.

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

commit 17528afcba
parent 78be400761
committed by GitHub
@@ -206,17 +206,15 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             prior_text_encoder_output = self.prior_text_encoder(text_input_ids.to(device))
 
             prompt_embeds = prior_text_encoder_output.text_embeds
-            prior_text_encoder_hidden_states = prior_text_encoder_output.last_hidden_state
+            text_enc_hid_states = prior_text_encoder_output.last_hidden_state
 
         else:
             batch_size = text_model_output[0].shape[0]
-            prompt_embeds, prior_text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+            prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1]
             text_mask = text_attention_mask
 
         prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-        prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.repeat_interleave(
-            num_images_per_prompt, dim=0
-        )
+        text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)
         text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
 
         if do_classifier_free_guidance:
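Note on the hunk above: besides the rename, the multi-line `repeat_interleave` call is collapsed onto one line; behavior is unchanged. For context, a minimal sketch of what `repeat_interleave(..., dim=0)` does to the embeddings here (toy shapes, not the pipeline's real dimensions):

```python
import torch

# Toy shapes: two prompts, three images per prompt.
prompt_embeds = torch.randn(2, 768)            # (batch_size, embed_dim)
text_enc_hid_states = torch.randn(2, 77, 768)  # (batch_size, seq_len, hidden_dim)

num_images_per_prompt = 3
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)

# Rows now repeat per prompt: [p0, p0, p0, p1, p1, p1]
print(prompt_embeds.shape)        # torch.Size([6, 768])
print(text_enc_hid_states.shape)  # torch.Size([6, 77, 768])
```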
@@ -235,9 +233,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             )
 
             negative_prompt_embeds = negative_prompt_embeds_prior_text_encoder_output.text_embeds
-            uncond_prior_text_encoder_hidden_states = (
-                negative_prompt_embeds_prior_text_encoder_output.last_hidden_state
-            )
+            uncond_text_enc_hid_states = negative_prompt_embeds_prior_text_encoder_output.last_hidden_state
 
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
 
@@ -245,11 +241,9 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
 
-            seq_len = uncond_prior_text_encoder_hidden_states.shape[1]
-            uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.repeat(
-                1, num_images_per_prompt, 1
-            )
-            uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.view(
+            seq_len = uncond_text_enc_hid_states.shape[1]
+            uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1)
+            uncond_text_enc_hid_states = uncond_text_enc_hid_states.view(
                 batch_size * num_images_per_prompt, seq_len, -1
             )
             uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
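The repeat-then-view pattern in this hunk is the "mps friendly method" the nearby comment refers to: it reproduces `repeat_interleave` along dim 0 without calling it, presumably because that op was unreliable on Apple's mps backend when this code was written. A self-contained sketch of the equivalence (toy shapes for illustration only):

```python
import torch

num_images_per_prompt = 3
uncond_text_enc_hid_states = torch.randn(2, 77, 768)  # (batch, seq_len, dim)

# repeat() tiles along dim 1, then view() folds the tiles back into dim 0.
seq_len = uncond_text_enc_hid_states.shape[1]
out = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1)
out = out.view(2 * num_images_per_prompt, seq_len, -1)

# Same result as a plain repeat_interleave along the batch dimension.
expected = uncond_text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)
assert torch.equal(out, expected)
```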
@@ -260,13 +254,11 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-            prior_text_encoder_hidden_states = torch.cat(
-                [uncond_prior_text_encoder_hidden_states, prior_text_encoder_hidden_states]
-            )
+            text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states])
 
             text_mask = torch.cat([uncond_text_mask, text_mask])
 
-        return prompt_embeds, prior_text_encoder_hidden_states, text_mask
+        return prompt_embeds, text_enc_hid_states, text_mask
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
     def _encode_prompt(
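The `torch.cat` lines in the hunk above implement the comment preceding them: unconditional and conditional embeddings are stacked into one batch so classifier-free guidance costs a single forward pass instead of two. A minimal sketch of the full pattern, with hypothetical names (`model` stands in for the prior or decoder call, not a real diffusers API):

```python
import torch

def guided_prediction(model, latents, t, cond_embeds, uncond_embeds, guidance_scale):
    # One forward pass over the doubled batch [uncond; cond] ...
    embeds = torch.cat([uncond_embeds, cond_embeds])
    latent_in = torch.cat([latents, latents])
    pred = model(latent_in, t, encoder_hidden_states=embeds)
    # ... then split the halves and blend: classifier-free guidance.
    pred_uncond, pred_text = pred.chunk(2)
    return pred_uncond + guidance_scale * (pred_text - pred_uncond)
```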
@@ -156,15 +156,15 @@ class UnCLIPPipeline(DiffusionPipeline):
             text_encoder_output = self.text_encoder(text_input_ids.to(device))
 
             prompt_embeds = text_encoder_output.text_embeds
-            text_encoder_hidden_states = text_encoder_output.last_hidden_state
+            text_enc_hid_states = text_encoder_output.last_hidden_state
 
         else:
             batch_size = text_model_output[0].shape[0]
-            prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+            prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1]
             text_mask = text_attention_mask
 
         prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+        text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)
         text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
 
         if do_classifier_free_guidance:
@@ -181,7 +181,7 @@ class UnCLIPPipeline(DiffusionPipeline):
             negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
 
             negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
-            uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+            uncond_text_enc_hid_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
 
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
 
@@ -189,9 +189,9 @@ class UnCLIPPipeline(DiffusionPipeline):
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
 
-            seq_len = uncond_text_encoder_hidden_states.shape[1]
-            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
-            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+            seq_len = uncond_text_enc_hid_states.shape[1]
+            uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1)
+            uncond_text_enc_hid_states = uncond_text_enc_hid_states.view(
                 batch_size * num_images_per_prompt, seq_len, -1
             )
             uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
@@ -202,11 +202,11 @@ class UnCLIPPipeline(DiffusionPipeline):
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+            text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states])
 
             text_mask = torch.cat([uncond_text_mask, text_mask])
 
-        return prompt_embeds, text_encoder_hidden_states, text_mask
+        return prompt_embeds, text_enc_hid_states, text_mask
 
     @torch.no_grad()
     def __call__(
@@ -293,7 +293,7 @@ class UnCLIPPipeline(DiffusionPipeline):
 
         do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
 
-        prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
+        prompt_embeds, text_enc_hid_states, text_mask = self._encode_prompt(
             prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
         )
 
@@ -321,7 +321,7 @@ class UnCLIPPipeline(DiffusionPipeline):
                 latent_model_input,
                 timestep=t,
                 proj_embedding=prompt_embeds,
-                encoder_hidden_states=text_encoder_hidden_states,
+                encoder_hidden_states=text_enc_hid_states,
                 attention_mask=text_mask,
             ).predicted_image_embedding
 
@@ -352,10 +352,10 @@ class UnCLIPPipeline(DiffusionPipeline):
 
         # decoder
 
-        text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
+        text_enc_hid_states, additive_clip_time_embeddings = self.text_proj(
             image_embeddings=image_embeddings,
             prompt_embeds=prompt_embeds,
-            text_encoder_hidden_states=text_encoder_hidden_states,
+            text_encoder_hidden_states=text_enc_hid_states,
             do_classifier_free_guidance=do_classifier_free_guidance,
         )
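One detail worth noting in this hunk, likely the "fix keyword argument call" item from the commit message: the keyword `text_encoder_hidden_states=` belongs to `self.text_proj`'s signature and stays unchanged; only the caller's local variable becomes `text_enc_hid_states`. A hypothetical, minimal illustration of why the two names diverge:

```python
import torch

# Hypothetical signature: the keyword name is part of the callee's API.
def text_proj(*, text_encoder_hidden_states):
    return text_encoder_hidden_states.shape

text_enc_hid_states = torch.randn(2, 77, 768)  # renamed local variable
print(text_proj(text_encoder_hidden_states=text_enc_hid_states))  # torch.Size([2, 77, 768])
```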
@@ -377,7 +377,7 @@ class UnCLIPPipeline(DiffusionPipeline):
 
         decoder_latents = self.prepare_latents(
             (batch_size, num_channels_latents, height, width),
-            text_encoder_hidden_states.dtype,
+            text_enc_hid_states.dtype,
             device,
             generator,
             decoder_latents,
@@ -391,7 +391,7 @@ class UnCLIPPipeline(DiffusionPipeline):
             noise_pred = self.decoder(
                 sample=latent_model_input,
                 timestep=t,
-                encoder_hidden_states=text_encoder_hidden_states,
+                encoder_hidden_states=text_enc_hid_states,
                 class_labels=additive_clip_time_embeddings,
                 attention_mask=decoder_text_mask,
             ).sample
@@ -1494,7 +1494,6 @@ class ResnetBlockFlat(nn.Module):
         return output_tensor
 
 
-# Copied from diffusers.models.unet_2d_blocks.DownBlock2D with DownBlock2D->DownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
 class DownBlockFlat(nn.Module):
     def __init__(
         self,
@@ -1583,7 +1582,6 @@ class DownBlockFlat(nn.Module):
         return hidden_states, output_states
 
 
-# Copied from diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D with CrossAttnDownBlock2D->CrossAttnDownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
 class CrossAttnDownBlockFlat(nn.Module):
     def __init__(
         self,
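These last two hunks drop `# Copied from ...` markers rather than rename anything. In diffusers, such markers pin a class to its source definition and are checked by `utils/check_copies.py` (run via `make fix-copies`); removing the marker, as done here for `DownBlockFlat` and `CrossAttnDownBlockFlat`, opts those classes out of the automated sync check, presumably because the Flat variants no longer match their 2D sources verbatim after the renames above.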