mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

addressed PR comments

ishan-modi committed 2025-04-09 12:07:50 +05:30
parent e6ba267225
commit 577651786f

4 changed files with 98 additions and 69 deletions


@@ -32,5 +32,5 @@ The original codebase can be found at [NVlabs/Sana](https://github.com/NVlabs/Sa
   - all
   - __call__
 
-## SanaControlNetPipelineOutput
-[[autodoc]] pipelines.controlnet_sana.SanaControlNetPipelineOutput
+## SanaPipelineOutput
+[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput


@@ -354,9 +354,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         if device is None:
             device = self._execution_device
 
-        if self.transformer is not None:
-            dtype = self.transformer.dtype
-        elif self.text_encoder is not None:
+        if self.text_encoder is not None:
             dtype = self.text_encoder.dtype
         else:
             dtype = None
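Context for this hunk: prompt embeddings should be produced in the text encoder's precision, while the transformer's dtype (which may differ, for example under quantization) is now applied later, inside the denoising loop. A minimal sketch of the resulting flow; the dtypes and shapes are illustrative, not taken from this diff:

import torch

text_encoder_dtype = torch.bfloat16  # what self.text_encoder.dtype would report
transformer_dtype = torch.float16    # may differ from the encoder, e.g. a quantized transformer

# Embeddings are created in the text encoder's dtype...
prompt_embeds = torch.randn(2, 300, 2304, dtype=text_encoder_dtype)

# ...and cast to the transformer's dtype only at the model boundary, per step:
model_input_embeds = prompt_embeds.to(dtype=transformer_dtype)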
@@ -928,23 +926,23 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
+        transformer_dtype = self.transformer.dtype
 
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                latent_model_input = latent_model_input.to(prompt_embeds.dtype)
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+                timestep = t.expand(latent_model_input.shape[0])
 
                 # predict noise model_output
                 noise_pred = self.transformer(
-                    latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
+                    latent_model_input.to(dtype=transformer_dtype),
+                    encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
                     encoder_attention_mask=prompt_attention_mask,
-                    timestep=timestep,
+                    timestep=timestep.to(dtype=transformer_dtype),
                     return_dict=False,
                     attention_kwargs=self.attention_kwargs,
                 )[0]
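The pattern introduced above reads the transformer dtype once before the loop and casts each per-step input at the call site. A condensed, self-contained sketch of one step; the shapes and the keyword names in the dict are illustrative assumptions:

import torch

transformer_dtype = torch.bfloat16  # mirrors `transformer_dtype = self.transformer.dtype`

latents = torch.randn(1, 32, 32, 32)       # scheduler math stays in the latents' own precision
prompt_embeds = torch.randn(2, 300, 2304)  # [negative; positive] rows for CFG
t = torch.tensor(999.0)

latent_model_input = torch.cat([latents] * 2)     # CFG doubles the batch
timestep = t.expand(latent_model_input.shape[0])  # broadcast only; no dtype cast here anymore

# all casting happens once, at the model boundary
model_inputs = dict(
    hidden_states=latent_model_input.to(dtype=transformer_dtype),
    encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
    timestep=timestep.to(dtype=transformer_dtype),
)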


@@ -261,6 +261,66 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
"""
self.vae.disable_tiling()
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
def _get_gemma_prompt_embeds(
self,
prompt: Union[str, List[str]],
device: torch.device,
dtype: torch.dtype,
clean_caption: bool = False,
max_sequence_length: int = 300,
complex_human_instruction: Optional[List[str]] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
clean_caption (`bool`, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding.
max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
the prompt.
"""
prompt = [prompt] if isinstance(prompt, str) else prompt
if getattr(self, "tokenizer", None) is not None:
self.tokenizer.padding_side = "right"
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
# prepare complex human instruction
if not complex_human_instruction:
max_length_all = max_sequence_length
else:
chi_prompt = "\n".join(complex_human_instruction)
prompt = [chi_prompt + p for p in prompt]
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length_all,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.to(device)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
return prompt_embeds, prompt_attention_mask
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
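One detail of the helper added above: when a complex-human-instruction prefix is present, the tokenizer budget grows to fit it. A worked example with made-up token counts; reading the `- 2` as an offset for special tokens counted in both encodings is an editorial assumption, not something stated in the diff:

max_sequence_length = 300    # prompt token budget
num_chi_prompt_tokens = 120  # hypothetical len(self.tokenizer.encode(chi_prompt))
max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
assert max_length_all == 418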
@@ -309,6 +369,11 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         if device is None:
             device = self._execution_device
 
+        if self.text_encoder is not None:
+            dtype = self.text_encoder.dtype
+        else:
+            dtype = None
+
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
         if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
@@ -333,37 +398,18 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
             select_index = [0] + list(range(-max_length + 1, 0))
 
         if prompt_embeds is None:
-            prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
-
-            # prepare complex human instruction
-            if not complex_human_instruction:
-                max_length_all = max_length
-            else:
-                chi_prompt = "\n".join(complex_human_instruction)
-                prompt = [chi_prompt + p for p in prompt]
-                num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
-                max_length_all = num_chi_prompt_tokens + max_length - 2
-
-            text_inputs = self.tokenizer(
-                prompt,
-                padding="max_length",
-                max_length=max_length_all,
-                truncation=True,
-                add_special_tokens=True,
-                return_tensors="pt",
+            prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
+                prompt=prompt,
+                device=device,
+                dtype=dtype,
+                clean_caption=clean_caption,
+                max_sequence_length=max_sequence_length,
+                complex_human_instruction=complex_human_instruction,
             )
-            text_input_ids = text_inputs.input_ids
-
-            prompt_attention_mask = text_inputs.attention_mask
-            prompt_attention_mask = prompt_attention_mask.to(device)
-
-            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
-            prompt_embeds = prompt_embeds[0][:, select_index]
+
+            prompt_embeds = prompt_embeds[:, select_index]
             prompt_attention_mask = prompt_attention_mask[:, select_index]
 
-        dtype = self.text_encoder.dtype
-        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
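The `select_index` trimming that survives this refactor is easiest to see with small numbers; a sketch with illustrative values:

import torch

max_length = 5
select_index = [0] + list(range(-max_length + 1, 0))
print(select_index)  # [0, -4, -3, -2, -1]

# keeps position 0 (the BOS token) plus the trailing max_length - 1 positions,
# dropping the instruction-prefix tokens in between
prompt_embeds = torch.randn(1, 12, 8)     # 12 encoded tokens, hypothetical
trimmed = prompt_embeds[:, select_index]  # shape (1, 5, 8)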
@@ -373,29 +419,17 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
-            uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
-            max_length = prompt_embeds.shape[1]
-            uncond_input = self.tokenizer(
-                uncond_tokens,
-                padding="max_length",
-                max_length=max_length,
-                truncation=True,
-                return_attention_mask=True,
-                add_special_tokens=True,
-                return_tensors="pt",
+            negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
+                prompt=negative_prompt,
+                device=device,
+                dtype=dtype,
+                clean_caption=clean_caption,
+                max_sequence_length=max_sequence_length,
+                complex_human_instruction=False,
             )
-            negative_prompt_attention_mask = uncond_input.attention_mask
-            negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
-            negative_prompt_embeds = self.text_encoder(
-                uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
-            )
-            negative_prompt_embeds = negative_prompt_embeds[0]
 
         if do_classifier_free_guidance:
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
-
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
@@ -947,9 +981,8 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                 control_image = self.vae.encode(control_image).latent
                 control_image = control_image * self.vae.config.scaling_factor
             else:
-                assert False
+                raise ValueError("`controlnet` must be of type `SanaControlNetModel`.")
 
         # 5. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(
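The rationale for that one-liner, in miniature: a bare `assert False` carries no message and is removed entirely under `python -O`, while the explicit exception survives optimized mode and names the expected type. A hypothetical reproduction with stand-in classes:

class SanaControlNetModel:  # stand-in for the real model class
    pass

controlnet = object()  # wrong type supplied by a hypothetical caller
try:
    if not isinstance(controlnet, SanaControlNetModel):
        raise ValueError("`controlnet` must be of type `SanaControlNetModel`.")
except ValueError as err:
    print(err)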
@@ -993,7 +1026,7 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                     latent_model_input.to(dtype=controlnet_dtype),
                     encoder_hidden_states=prompt_embeds.to(dtype=controlnet_dtype),
                     encoder_attention_mask=prompt_attention_mask,
-                    timestep=timestep,
+                    timestep=timestep.to(dtype=controlnet_dtype),
                     return_dict=False,
                     attention_kwargs=self.attention_kwargs,
                     controlnet_cond=control_image,
@@ -1005,7 +1038,7 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                     latent_model_input.to(dtype=transformer_dtype),
                     encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
                     encoder_attention_mask=prompt_attention_mask,
-                    timestep=timestep,
+                    timestep=timestep.to(dtype=transformer_dtype),
                     return_dict=False,
                     attention_kwargs=self.attention_kwargs,
                     controlnet_block_samples=controlnet_block_samples.to(dtype=transformer_dtype),
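The same casting discipline bridges the two models: the ControlNet runs in its own precision, and its block residuals are cast once at the transformer boundary. A sketch with illustrative dtypes and shapes:

import torch

controlnet_dtype = torch.float16    # what self.controlnet.dtype would report
transformer_dtype = torch.bfloat16  # what self.transformer.dtype would report

# residuals leave the ControlNet in its precision...
controlnet_block_samples = torch.randn(2, 1024, 2240, dtype=controlnet_dtype)

# ...and are cast before entering the transformer, like every other input above
residuals_for_transformer = controlnet_block_samples.to(dtype=transformer_dtype)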


@@ -295,9 +295,7 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         if device is None:
             device = self._execution_device
 
-        if self.transformer is not None:
-            dtype = self.transformer.dtype
-        elif self.text_encoder is not None:
+        if self.text_encoder is not None:
             dtype = self.text_encoder.dtype
         else:
             dtype = None
@@ -806,13 +804,14 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
+        transformer_dtype = self.transformer.dtype
 
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0]).to(prompt_embeds.dtype)
+                timestep = t.expand(latents.shape[0])
 
                 latents_model_input = latents / self.scheduler.config.sigma_data
                 scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep))
@@ -821,15 +820,14 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                 latent_model_input = latents_model_input * torch.sqrt(
                     scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2
                 )
-                latent_model_input = latent_model_input.to(prompt_embeds.dtype)
 
                 # predict noise model_output
                 noise_pred = self.transformer(
-                    latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
+                    latent_model_input.to(dtype=transformer_dtype),
+                    encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
                     encoder_attention_mask=prompt_attention_mask,
                     guidance=guidance,
-                    timestep=scm_timestep,
+                    timestep=scm_timestep.to(dtype=transformer_dtype),
                     return_dict=False,
                     attention_kwargs=self.attention_kwargs,
                 )[0]
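For reference, the sCM timestep arithmetic these lines feed, assembled from the unchanged context above; `sigma_data` and the shapes are assumed values for illustration:

import torch

sigma_data = 0.5  # stands in for self.scheduler.config.sigma_data
transformer_dtype = torch.bfloat16

t = torch.tensor([1.3])  # one raw scheduler timestep
latents = torch.randn(1, 32, 32, 32)

latents_model_input = latents / sigma_data

# TrigFlow-style timestep: s = sin(t) / (cos(t) + sin(t))
scm_timestep = torch.sin(t) / (torch.cos(t) + torch.sin(t))

# per-sample scaling (mirrors `scm_timestep_expanded` in the context above)
s = scm_timestep.view(-1, 1, 1, 1)
latent_model_input = latents_model_input * torch.sqrt(s**2 + (1 - s) ** 2)

# only now is everything cast to the transformer's dtype, as in the hunk above
latent_model_input = latent_model_input.to(dtype=transformer_dtype)
scm_timestep = scm_timestep.to(dtype=transformer_dtype)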