diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py
index 11d4fd0f06..b05ca1f5ea 100644
--- a/src/diffusers/pipelines/photon/pipeline_photon.py
+++ b/src/diffusers/pipelines/photon/pipeline_photon.py
@@ -245,13 +245,13 @@ class PhotonPipeline(
 
     model_cpu_offload_seq = "text_encoder->transformer->vae"
     _callback_tensor_inputs = ["latents"]
-    _optional_components = []
+    _optional_components = ["vae"]
 
     def __init__(
         self,
         transformer: PhotonTransformer2DModel,
         scheduler: FlowMatchEulerDiscreteScheduler,
-        text_encoder: Union[T5GemmaEncoder],
+        text_encoder: T5GemmaEncoder,
         tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer],
         vae: Optional[Union[AutoencoderKL, AutoencoderDC]] = None,
         default_sample_size: Optional[int] = DEFAULT_RESOLUTION,
@@ -330,6 +330,11 @@ class PhotonPipeline(
         """Compatibility property that returns spatial compression ratio."""
         return getattr(self.vae, "spatial_compression_ratio", 8)
 
+    @property
+    def do_classifier_free_guidance(self):
+        """Check if classifier-free guidance is enabled based on guidance scale."""
+        return self._guidance_scale > 1.0
+
     def prepare_latents(
         self,
         batch_size: int,
@@ -353,49 +358,67 @@ class PhotonPipeline(
         latents = latents.to(device)
         return latents
 
-    def encode_prompt(self, prompt: Union[str, List[str]], device: torch.device):
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        device: torch.device,
+        do_classifier_free_guidance: bool = True,
+        negative_prompt: str = "",
+    ):
         """Encode text prompt using standard text encoder and tokenizer."""
         if isinstance(prompt, str):
             prompt = [prompt]
-        return self._encode_prompt_standard(prompt, device)
-
-    def _encode_prompt_standard(self, prompt: List[str], device: torch.device):
-        """Encode prompt using standard text encoder and tokenizer with batch processing."""
-        # Clean text using modular preprocessor
-        cleaned_prompts = [self.text_preprocessor.clean_text(text) for text in prompt]
-        cleaned_uncond_prompts = [self.text_preprocessor.clean_text("") for _ in prompt]
-
-        all_prompts = cleaned_prompts + cleaned_uncond_prompts
+        return self._encode_prompt_standard(prompt, device, do_classifier_free_guidance, negative_prompt)
 
+    def _tokenize_prompts(self, prompts: List[str], device: torch.device):
+        """Tokenize and clean prompts."""
+        cleaned = [self.text_preprocessor.clean_text(text) for text in prompts]
         tokens = self.tokenizer(
-            all_prompts,
+            cleaned,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
             truncation=True,
             return_attention_mask=True,
             return_tensors="pt",
         )
+        return tokens["input_ids"].to(device), tokens["attention_mask"].bool().to(device)
 
-        input_ids = tokens["input_ids"].to(device)
-        attention_mask = tokens["attention_mask"].bool().to(device)
+    def _encode_prompt_standard(
+        self,
+        prompt: List[str],
+        device: torch.device,
+        do_classifier_free_guidance: bool = True,
+        negative_prompt: str = "",
+    ):
+        """Encode prompt using standard text encoder and tokenizer with batch processing."""
+        batch_size = len(prompt)
+
+        if do_classifier_free_guidance:
+            if isinstance(negative_prompt, str):
+                negative_prompt = [negative_prompt] * batch_size
+
+            prompts_to_encode = negative_prompt + prompt
+        else:
+            prompts_to_encode = prompt
+
+        input_ids, attention_mask = self._tokenize_prompts(prompts_to_encode, device)
 
         with torch.no_grad():
-            emb = self.text_encoder(
+            embeddings = self.text_encoder(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 output_hidden_states=True,
-            )
+            )["last_hidden_state"]
 
-        all_embeddings = emb["last_hidden_state"]
-
-        # Split back into conditional and unconditional
-        batch_size = len(prompt)
-        text_embeddings = all_embeddings[:batch_size]
-        uncond_text_embeddings = all_embeddings[batch_size:]
-
-        cross_attn_mask = attention_mask[:batch_size]
-        uncond_cross_attn_mask = attention_mask[batch_size:]
+        if do_classifier_free_guidance:
+            uncond_text_embeddings, text_embeddings = embeddings.split(batch_size, dim=0)
+            uncond_cross_attn_mask, cross_attn_mask = attention_mask.split(batch_size, dim=0)
+        else:
+            text_embeddings = embeddings
+            cross_attn_mask = attention_mask
+            uncond_text_embeddings = None
+            uncond_cross_attn_mask = None
 
         return text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask
 
@@ -534,9 +557,11 @@ class PhotonPipeline(
 
         device = self._execution_device
 
+        self._guidance_scale = guidance_scale
+
        # 2. Encode input prompt
         text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt(
-            prompt, device
+            prompt, device, do_classifier_free_guidance=self.do_classifier_free_guidance
         )
 
         # 3. Prepare timesteps
@@ -572,17 +597,22 @@ class PhotonPipeline(
 
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
-                # Duplicate latents for CFG
-                latents_in = torch.cat([latents, latents], dim=0)
-
-                # Cross-attention batch (uncond, cond)
-                ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
-                ca_mask = None
-                if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
-                    ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
-
-                # Normalize timestep for the transformer
-                t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device)
+                # Duplicate latents if using classifier-free guidance
+                if self.do_classifier_free_guidance:
+                    latents_in = torch.cat([latents, latents], dim=0)
+                    # Cross-attention batch (uncond, cond)
+                    ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
+                    ca_mask = None
+                    if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
+                        ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
+                    # Normalize timestep for the transformer
+                    t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device)
+                else:
+                    latents_in = latents
+                    ca_embed = text_embeddings
+                    ca_mask = cross_attn_mask
+                    # Normalize timestep for the transformer
+                    t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).to(device)
 
                 # Process inputs for transformer
                 img_seq, txt, pe = self.transformer._process_inputs(latents_in, ca_embed)
@@ -597,11 +627,12 @@ class PhotonPipeline(
                 )
 
                 # Convert back to image format
-                noise_both = seq2img(img_seq, self.transformer.patch_size, latents_in.shape)
+                noise_pred = seq2img(img_seq, self.transformer.patch_size, latents_in.shape)
 
                 # Apply CFG
-                noise_uncond, noise_text = noise_both.chunk(2, dim=0)
-                noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
+                if self.do_classifier_free_guidance:
+                    noise_uncond, noise_text = noise_pred.chunk(2, dim=0)
+                    noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
 
                 # Compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample