Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
addressed PR comments
@@ -32,5 +32,5 @@ The original codebase can be found at [NVlabs/Sana](https://github.com/NVlabs/Sana).
     - all
     - __call__
 
-## SanaControlNetPipelineOutput
-[[autodoc]] pipelines.controlnet_sana.SanaControlNetPipelineOutput
+## SanaPipelineOutput
+[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput
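
Note: after this change the ControlNet docs point at the shared `SanaPipelineOutput` rather than a separate ControlNet output class. A minimal usage sketch; the checkpoint id is illustrative and not part of this commit:

import torch
from diffusers import SanaPipeline

# Illustrative checkpoint; any Sana checkpoint in the diffusers format works here.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# The call returns a SanaPipelineOutput; its `images` field holds the PIL images.
output = pipe(prompt="a cyberpunk cat", num_inference_steps=20)
output.images[0].save("cat.png")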
@@ -354,9 +354,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         if device is None:
             device = self._execution_device
 
-        if self.transformer is not None:
-            dtype = self.transformer.dtype
-        elif self.text_encoder is not None:
+        if self.text_encoder is not None:
             dtype = self.text_encoder.dtype
         else:
             dtype = None
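
Note: the removed branch made prompt embeddings follow the transformer's dtype; after this change they follow the text encoder's dtype, and the transformer's dtype is applied later at the call site. The fallback in isolation, as a sketch (function name is ours):

import torch

def resolve_embed_dtype(text_encoder) -> "torch.dtype | None":
    # Embeddings follow the text encoder when it is loaded; when the pipeline
    # runs from precomputed prompt_embeds there is no encoder and dtype stays None.
    if text_encoder is not None:
        return text_encoder.dtype
    return None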
@@ -928,23 +926,23 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
 
+        transformer_dtype = self.transformer.dtype
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                latent_model_input = latent_model_input.to(prompt_embeds.dtype)
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+                timestep = t.expand(latent_model_input.shape[0])
 
                 # predict noise model_output
                 noise_pred = self.transformer(
-                    latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
+                    latent_model_input.to(dtype=transformer_dtype),
+                    encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
                     encoder_attention_mask=prompt_attention_mask,
-                    timestep=timestep,
+                    timestep=timestep.to(dtype=transformer_dtype),
                     return_dict=False,
                     attention_kwargs=self.attention_kwargs,
                 )[0]
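
Note: the loop no longer pre-casts latents to `prompt_embeds.dtype`; every tensor handed to the transformer is cast once, at the call, to `transformer.dtype`. A reduced sketch of the pattern (the call signature is abbreviated, not the pipeline's full API):

import torch

@torch.no_grad()
def denoise_step(transformer, latents, prompt_embeds, t):
    # Latents and timesteps keep their working dtype (schedulers often want fp32);
    # only the copies passed to the model are cast, so a bf16 or quantized
    # transformer never forces the rest of the loop into its precision.
    model_dtype = transformer.dtype  # diffusers models expose their parameter dtype
    timestep = t.expand(latents.shape[0])
    return transformer(
        latents.to(dtype=model_dtype),
        encoder_hidden_states=prompt_embeds.to(dtype=model_dtype),
        timestep=timestep.to(dtype=model_dtype),
        return_dict=False,
    )[0]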
@@ -261,6 +261,66 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         """
         self.vae.disable_tiling()
 
+    # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
+    def _get_gemma_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]],
+        device: torch.device,
+        dtype: torch.dtype,
+        clean_caption: bool = False,
+        max_sequence_length: int = 300,
+        complex_human_instruction: Optional[List[str]] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`torch.device`, *optional*):
+                torch device to place the resulting embeddings on
+            clean_caption (`bool`, defaults to `False`):
+                If `True`, the function will preprocess and clean the provided caption before encoding.
+            max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+            complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+                If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+                the prompt.
+        """
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if getattr(self, "tokenizer", None) is not None:
+            self.tokenizer.padding_side = "right"
+
+        prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+
+        # prepare complex human instruction
+        if not complex_human_instruction:
+            max_length_all = max_sequence_length
+        else:
+            chi_prompt = "\n".join(complex_human_instruction)
+            prompt = [chi_prompt + p for p in prompt]
+            num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
+            max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
+
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_length_all,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+
+        prompt_attention_mask = text_inputs.attention_mask
+        prompt_attention_mask = prompt_attention_mask.to(device)
+
+        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+        prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
+
+        return prompt_embeds, prompt_attention_mask
+
     # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
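
Note: the new helper prepends the "complex human instruction" (CHI) to every prompt and widens the padding budget to match. Just that bookkeeping, as a standalone sketch (the `- 2` mirrors the source; the function name is ours):

from typing import List, Optional

def chi_max_length(tokenizer, complex_human_instruction: Optional[List[str]], max_sequence_length: int = 300) -> int:
    # Without CHI the budget is just max_sequence_length; with CHI it grows by
    # the instruction's token count (minus 2, following the source).
    if not complex_human_instruction:
        return max_sequence_length
    chi_prompt = "\n".join(complex_human_instruction)
    return len(tokenizer.encode(chi_prompt)) + max_sequence_length - 2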
@@ -309,6 +369,11 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         if device is None:
             device = self._execution_device
 
+        if self.text_encoder is not None:
+            dtype = self.text_encoder.dtype
+        else:
+            dtype = None
+
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
         if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
@@ -333,37 +398,18 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         select_index = [0] + list(range(-max_length + 1, 0))
 
         if prompt_embeds is None:
-            prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
-
-            # prepare complex human instruction
-            if not complex_human_instruction:
-                max_length_all = max_length
-            else:
-                chi_prompt = "\n".join(complex_human_instruction)
-                prompt = [chi_prompt + p for p in prompt]
-                num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
-                max_length_all = num_chi_prompt_tokens + max_length - 2
-
-            text_inputs = self.tokenizer(
-                prompt,
-                padding="max_length",
-                max_length=max_length_all,
-                truncation=True,
-                add_special_tokens=True,
-                return_tensors="pt",
+            prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
+                prompt=prompt,
+                device=device,
+                dtype=dtype,
+                clean_caption=clean_caption,
+                max_sequence_length=max_sequence_length,
+                complex_human_instruction=complex_human_instruction,
             )
-            text_input_ids = text_inputs.input_ids
-
-            prompt_attention_mask = text_inputs.attention_mask
-            prompt_attention_mask = prompt_attention_mask.to(device)
-
-            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
-            prompt_embeds = prompt_embeds[0][:, select_index]
+            prompt_embeds = prompt_embeds[:, select_index]
             prompt_attention_mask = prompt_attention_mask[:, select_index]
 
-            dtype = self.text_encoder.dtype
             prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
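
Note: `select_index = [0] + list(range(-max_length + 1, 0))` keeps position 0 (the BOS token) plus the last `max_length - 1` positions, which is how the prepended CHI tokens get dropped again after encoding. A quick worked example:

import torch

max_length = 4
select_index = [0] + list(range(-max_length + 1, 0))  # [0, -3, -2, -1]

embeds = torch.arange(10).reshape(1, 10, 1)  # fake (batch, seq_len, dim) tensor
print(embeds[:, select_index].flatten().tolist())  # [0, 7, 8, 9]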
@@ -373,29 +419,17 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
 
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
-            uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
-            max_length = prompt_embeds.shape[1]
-            uncond_input = self.tokenizer(
-                uncond_tokens,
-                padding="max_length",
-                max_length=max_length,
-                truncation=True,
-                return_attention_mask=True,
-                add_special_tokens=True,
-                return_tensors="pt",
+            negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
+                prompt=negative_prompt,
+                device=device,
+                dtype=dtype,
+                clean_caption=clean_caption,
+                max_sequence_length=max_sequence_length,
+                complex_human_instruction=False,
             )
-            negative_prompt_attention_mask = uncond_input.attention_mask
-            negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
-
-            negative_prompt_embeds = self.text_encoder(
-                uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
-            )
-            negative_prompt_embeds = negative_prompt_embeds[0]
 
         if do_classifier_free_guidance:
             negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
 
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
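
Note: with both branches routed through `_get_gemma_prompt_embeds`, the negative path differs only in passing `complex_human_instruction=False`, so no instruction is prepended to the unconditional prompt. Downstream, classifier-free guidance combines the two predictions in the usual way (a sketch, not code from this diff):

import torch

def apply_cfg(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # noise_pred holds the [negative, positive] halves of the doubled batch.
    noise_uncond, noise_text = noise_pred.chunk(2)
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)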
@@ -947,9 +981,8 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
 
             control_image = self.vae.encode(control_image).latent
             control_image = control_image * self.vae.config.scaling_factor
-
         else:
-            assert False
+            raise ValueError("`controlnet` must be of type `SanaControlNetModel`.")
 
         # 5. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(
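
Note: `assert False` disappears when Python runs with `-O` and carries no message; raising keeps the guard in optimized runs and names the fix. The pattern in isolation (function name is ours):

def check_controlnet_type(controlnet, expected_cls) -> None:
    # A ValueError survives `python -O` and tells the caller what to pass instead.
    if not isinstance(controlnet, expected_cls):
        raise ValueError(f"`controlnet` must be of type `{expected_cls.__name__}`.")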
@@ -993,7 +1026,7 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                     latent_model_input.to(dtype=controlnet_dtype),
                     encoder_hidden_states=prompt_embeds.to(dtype=controlnet_dtype),
                     encoder_attention_mask=prompt_attention_mask,
-                    timestep=timestep,
+                    timestep=timestep.to(dtype=controlnet_dtype),
                     return_dict=False,
                     attention_kwargs=self.attention_kwargs,
                     controlnet_cond=control_image,
@@ -1005,7 +1038,7 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                     latent_model_input.to(dtype=transformer_dtype),
                     encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
                     encoder_attention_mask=prompt_attention_mask,
-                    timestep=timestep,
+                    timestep=timestep.to(dtype=transformer_dtype),
                     return_dict=False,
                     attention_kwargs=self.attention_kwargs,
                     controlnet_block_samples=controlnet_block_samples.to(dtype=transformer_dtype),
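
Note: the ControlNet pass and the transformer pass now cast their inputs independently, so the two models can run in different precisions (for example a quantized transformer alongside a bf16 ControlNet). The per-step flow, reduced to a sketch (signatures abbreviated, not the pipeline's exact API):

import torch

@torch.no_grad()
def controlled_step(controlnet, transformer, latents, prompt_embeds, t, control_image):
    cn_dtype, tf_dtype = controlnet.dtype, transformer.dtype
    timestep = t.expand(latents.shape[0])

    # First pass: the ControlNet consumes inputs in its own dtype.
    block_samples = controlnet(
        latents.to(dtype=cn_dtype),
        encoder_hidden_states=prompt_embeds.to(dtype=cn_dtype),
        timestep=timestep.to(dtype=cn_dtype),
        controlnet_cond=control_image.to(dtype=cn_dtype),
    )
    # Second pass: the transformer gets its own casts, including the block samples.
    return transformer(
        latents.to(dtype=tf_dtype),
        encoder_hidden_states=prompt_embeds.to(dtype=tf_dtype),
        timestep=timestep.to(dtype=tf_dtype),
        controlnet_block_samples=block_samples.to(dtype=tf_dtype),
    )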
@@ -295,9 +295,7 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         if device is None:
             device = self._execution_device
 
-        if self.transformer is not None:
-            dtype = self.transformer.dtype
-        elif self.text_encoder is not None:
+        if self.text_encoder is not None:
            dtype = self.text_encoder.dtype
         else:
             dtype = None
@@ -806,13 +804,14 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
 
+        transformer_dtype = self.transformer.dtype
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0]).to(prompt_embeds.dtype)
+                timestep = t.expand(latents.shape[0])
                 latents_model_input = latents / self.scheduler.config.sigma_data
 
                 scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep))
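
Note: Sprint feeds the transformer an sCM-style timestep, `s(t) = sin(t) / (cos(t) + sin(t))`, which maps `t = 0` to 0 and `t = pi/2` to 1, monotonically in between. A quick numeric check:

import torch

t = torch.tensor([0.0, torch.pi / 4, torch.pi / 2])
s = torch.sin(t) / (torch.cos(t) + torch.sin(t))
print(s)  # tensor([0.0000, 0.5000, 1.0000])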
@@ -821,15 +820,14 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                 latent_model_input = latents_model_input * torch.sqrt(
                     scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2
                 )
-                latent_model_input = latent_model_input.to(prompt_embeds.dtype)
 
                 # predict noise model_output
                 noise_pred = self.transformer(
-                    latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
+                    latent_model_input.to(dtype=transformer_dtype),
+                    encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
                     encoder_attention_mask=prompt_attention_mask,
                     guidance=guidance,
-                    timestep=scm_timestep,
+                    timestep=scm_timestep.to(dtype=transformer_dtype),
                     return_dict=False,
                     attention_kwargs=self.attention_kwargs,
                 )[0]
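
Note: the input scaling ahead of the call, `sqrt(s**2 + (1 - s)**2)`, equals 1 at `s = 0` or `s = 1` and bottoms out at `sqrt(0.5)` for `s = 0.5`; only after that rescaling is the input cast to the transformer dtype. The preprocessing in isolation (a sketch, function name is ours):

import torch

def scale_model_input(latents: torch.Tensor, s: torch.Tensor, sigma_data: float) -> torch.Tensor:
    # Normalize by sigma_data, then rescale; s broadcasts over the latent shape.
    x = latents / sigma_data
    return x * torch.sqrt(s**2 + (1 - s) ** 2)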