From 577651786fbf498cff4025f2b8c29944f2c8a020 Mon Sep 17 00:00:00 2001
From: ishan-modi
Date: Wed, 9 Apr 2025 12:07:50 +0530
Subject: [PATCH] addressed PR comments

---
 .../en/api/pipelines/controlnet_sana.md       |   4 +-
 src/diffusers/pipelines/sana/pipeline_sana.py |  14 +-
 .../sana/pipeline_sana_controlnet.py          | 135 +++++++++++-------
 .../pipelines/sana/pipeline_sana_sprint.py    |  14 +-
 4 files changed, 98 insertions(+), 69 deletions(-)

diff --git a/docs/source/en/api/pipelines/controlnet_sana.md b/docs/source/en/api/pipelines/controlnet_sana.md
index 67ec882d68..fa04591532 100644
--- a/docs/source/en/api/pipelines/controlnet_sana.md
+++ b/docs/source/en/api/pipelines/controlnet_sana.md
@@ -32,5 +32,5 @@ The original codebase can be found at [NVlabs/Sana](https://github.com/NVlabs/Sana).
   - all
   - __call__
 
-## SanaControlNetPipelineOutput
-[[autodoc]] pipelines.controlnet_sana.SanaControlNetPipelineOutput
+## SanaPipelineOutput
+[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput
\ No newline at end of file
diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py
index 0f0ae5f197..8998e7bfb7 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana.py
@@ -354,9 +354,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         if device is None:
             device = self._execution_device
 
-        if self.transformer is not None:
-            dtype = self.transformer.dtype
-        elif self.text_encoder is not None:
+        if self.text_encoder is not None:
             dtype = self.text_encoder.dtype
         else:
             dtype = None
@@ -928,23 +926,23 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
 
+        transformer_dtype = self.transformer.dtype
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                latent_model_input = latent_model_input.to(prompt_embeds.dtype)
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+                timestep = t.expand(latent_model_input.shape[0])
 
                 # predict noise model_output
                 noise_pred = self.transformer(
-                    latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
+                    latent_model_input.to(dtype=transformer_dtype),
+                    encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
                     encoder_attention_mask=prompt_attention_mask,
-                    timestep=timestep,
+                    timestep=timestep.to(dtype=transformer_dtype),
                     return_dict=False,
                     attention_kwargs=self.attention_kwargs,
                 )[0]
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py
index f363eb6058..8ff1a25271 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py
@@ -261,6 +261,66 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         """
         self.vae.disable_tiling()
 
+    # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
+    def _get_gemma_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]],
+        device: torch.device,
+        dtype: torch.dtype,
+        clean_caption: bool = False,
+        max_sequence_length: int = 300,
+        complex_human_instruction: Optional[List[str]] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`torch.device`, *optional*):
+                torch device to place the resulting embeddings on
+            clean_caption (`bool`, defaults to `False`):
+                If `True`, the function will preprocess and clean the provided caption before encoding.
+            max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+            complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+                If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+                the prompt.
+        """
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if getattr(self, "tokenizer", None) is not None:
+            self.tokenizer.padding_side = "right"
+
+        prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+
+        # prepare complex human instruction
+        if not complex_human_instruction:
+            max_length_all = max_sequence_length
+        else:
+            chi_prompt = "\n".join(complex_human_instruction)
+            prompt = [chi_prompt + p for p in prompt]
+            num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
+            max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
+
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_length_all,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+
+        prompt_attention_mask = text_inputs.attention_mask
+        prompt_attention_mask = prompt_attention_mask.to(device)
+
+        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+        prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
+
+        return prompt_embeds, prompt_attention_mask
+
+    # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
@@ -309,6 +369,11 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         if device is None:
             device = self._execution_device
 
+        if self.text_encoder is not None:
+            dtype = self.text_encoder.dtype
+        else:
+            dtype = None
+
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
         if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
@@ -333,37 +398,18 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         select_index = [0] + list(range(-max_length + 1, 0))
 
         if prompt_embeds is None:
-            prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
-
-            # prepare complex human instruction
-            if not complex_human_instruction:
-                max_length_all = max_length
-            else:
-                chi_prompt = "\n".join(complex_human_instruction)
-                prompt = [chi_prompt + p for p in prompt]
-                num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
-                max_length_all = num_chi_prompt_tokens + max_length - 2
-
-            text_inputs = self.tokenizer(
-                prompt,
-                padding="max_length",
-                max_length=max_length_all,
-                truncation=True,
-                add_special_tokens=True,
-                return_tensors="pt",
+            prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
+                prompt=prompt,
+                device=device,
+                dtype=dtype,
+                clean_caption=clean_caption,
+                max_sequence_length=max_sequence_length,
+                complex_human_instruction=complex_human_instruction,
             )
-            text_input_ids = text_inputs.input_ids
 
-            prompt_attention_mask = text_inputs.attention_mask
-            prompt_attention_mask = prompt_attention_mask.to(device)
-
-            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
-            prompt_embeds = prompt_embeds[0][:, select_index]
+            prompt_embeds = prompt_embeds[:, select_index]
             prompt_attention_mask = prompt_attention_mask[:, select_index]
 
-            dtype = self.text_encoder.dtype
-            prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -373,29 +419,17 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
 
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
-            uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
-            max_length = prompt_embeds.shape[1]
-            uncond_input = self.tokenizer(
-                uncond_tokens,
-                padding="max_length",
-                max_length=max_length,
-                truncation=True,
-                return_attention_mask=True,
-                add_special_tokens=True,
-                return_tensors="pt",
+            negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
+                prompt=negative_prompt,
+                device=device,
+                dtype=dtype,
+                clean_caption=clean_caption,
+                max_sequence_length=max_sequence_length,
+                complex_human_instruction=False,
             )
-            negative_prompt_attention_mask = uncond_input.attention_mask
-            negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
-
-            negative_prompt_embeds = self.text_encoder(
-                uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
-            )
-            negative_prompt_embeds = negative_prompt_embeds[0]
 
         if do_classifier_free_guidance:
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
-
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
@@ -947,9 +981,8 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
 
             control_image = self.vae.encode(control_image).latent
             control_image = control_image * self.vae.config.scaling_factor
-
         else:
-            assert False
+            raise ValueError("`controlnet` must be of type `SanaControlNetModel`.")
 
         # 5. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(
@@ -993,7 +1026,7 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                     latent_model_input.to(dtype=controlnet_dtype),
                     encoder_hidden_states=prompt_embeds.to(dtype=controlnet_dtype),
                     encoder_attention_mask=prompt_attention_mask,
-                    timestep=timestep,
+                    timestep=timestep.to(dtype=controlnet_dtype),
                     return_dict=False,
                     attention_kwargs=self.attention_kwargs,
                     controlnet_cond=control_image,
@@ -1005,7 +1038,7 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                     latent_model_input.to(dtype=transformer_dtype),
                     encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
                     encoder_attention_mask=prompt_attention_mask,
-                    timestep=timestep,
+                    timestep=timestep.to(dtype=transformer_dtype),
                     return_dict=False,
                     attention_kwargs=self.attention_kwargs,
                     controlnet_block_samples=controlnet_block_samples.to(dtype=transformer_dtype),
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
index 9b3acbb1cb..87bf40ce5e 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
@@ -295,9 +295,7 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         if device is None:
             device = self._execution_device
 
-        if self.transformer is not None:
-            dtype = self.transformer.dtype
-        elif self.text_encoder is not None:
+        if self.text_encoder is not None:
             dtype = self.text_encoder.dtype
         else:
             dtype = None
@@ -806,13 +804,14 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
 
+        transformer_dtype = self.transformer.dtype
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0]).to(prompt_embeds.dtype)
+                timestep = t.expand(latents.shape[0])
 
                 latents_model_input = latents / self.scheduler.config.sigma_data
                 scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep))
@@ -821,15 +820,14 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                 latent_model_input = latents_model_input * torch.sqrt(
                     scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2
                 )
-                latent_model_input = latent_model_input.to(prompt_embeds.dtype)
 
                 # predict noise model_output
                 noise_pred = self.transformer(
-                    latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
+                    latent_model_input.to(dtype=transformer_dtype),
+                    encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
                     encoder_attention_mask=prompt_attention_mask,
                     guidance=guidance,
-                    timestep=scm_timestep,
+                    timestep=scm_timestep.to(dtype=transformer_dtype),
                     return_dict=False,
                     attention_kwargs=self.attention_kwargs,
                 )[0]
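
Reviewer note (not part of the patch): below is a minimal sketch of the code
path this change touches -- loading the ControlNet and the base pipeline in
different dtypes, which the per-step casts to `controlnet_dtype` /
`transformer_dtype` above are meant to support. The checkpoint IDs and the
control image URL are placeholders, not real repos; `SanaControlNetModel`,
`SanaControlNetPipeline`, and `control_image` are the names used in this diff.

import torch
from diffusers import SanaControlNetModel, SanaControlNetPipeline
from diffusers.utils import load_image

# Placeholder checkpoint IDs -- substitute the Sana base / ControlNet weights
# you are actually testing with.
controlnet = SanaControlNetModel.from_pretrained(
    "your-org/sana-controlnet", torch_dtype=torch.float16
)
pipe = SanaControlNetPipeline.from_pretrained(
    "your-org/sana-base",
    controlnet=controlnet,  # passed-in component keeps its own (fp16) dtype
    torch_dtype=torch.bfloat16,  # hub-loaded components land in bf16
)
pipe.to("cuda")

# encode_prompt now derives its working dtype from the text encoder rather
# than the transformer, and the denoising loop casts latents and timesteps to
# the ControlNet / transformer dtype at each step, so mixed dtypes coexist.
control_image = load_image("https://example.com/control.png")  # placeholder URL
image = pipe(
    prompt="a portrait photo, cinematic lighting",
    control_image=control_image,
    num_inference_steps=20,
).images[0]
image.save("sana_controlnet_mixed_dtype.png")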