From 53476bfca92cb7a2cb37c7340d9f9d24fa938528 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 10 Oct 2025 11:29:19 +0530
Subject: [PATCH] up

---
 .../community/flux_adversarial_latents.py | 24 ++++++++++--------------
 1 file changed, 10 insertions(+), 14 deletions(-)

diff --git a/examples/community/flux_adversarial_latents.py b/examples/community/flux_adversarial_latents.py
index acb6a71cc7..b279778eb4 100644
--- a/examples/community/flux_adversarial_latents.py
+++ b/examples/community/flux_adversarial_latents.py
@@ -27,8 +27,8 @@ class CLIPScore(nn.Module):
         self.eval()
 
     def forward(self, images: torch.Tensor, prompts: list[str]) -> torch.Tensor:
-        pixel_values = self._prepare_images(images)
         device = next(self.model.parameters()).device
+        pixel_values = self._prepare_images(images).to(device=device, dtype=torch.float32)
         text_inputs = self.tokenizer(list(prompts), padding=True, truncation=True, return_tensors="pt").to(device)
 
         image_embeds = self.model.get_image_features(pixel_values=pixel_values)
@@ -92,7 +92,7 @@ class AdversarialFluxPipeline(FluxPipeline):
         reward_prompts = self._expand_prompts(reward_prompt if reward_prompt is not None else prompt, latents.shape[0])
 
         with torch.no_grad():
-            current_images = self._decode_packed_latents(latents, height, width)
+            current_images = self._decode_packed_latents(latents, height, width).to(dtype=torch.float32)
             current_scores = self.reward_model(current_images, reward_prompts)
 
         intermediate_images: list[Image.Image] = []
@@ -102,30 +102,26 @@ class AdversarialFluxPipeline(FluxPipeline):
         score_trace: list[float] = [current_scores.mean().item()]
 
         for _ in range(num_rounds):
-            latents.requires_grad_(True)
-            decoded = self._decode_packed_latents(latents, height, width)
-            scores = self.reward_model(decoded, reward_prompts)
+            current_images.requires_grad_(True)
+            scores = self.reward_model(current_images, reward_prompts)
             total_score = scores.mean()
-            grad = torch.autograd.grad(total_score, latents)[0]
+            grad = torch.autograd.grad(total_score, current_images)[0]
 
             with torch.no_grad():
-                latents = latents + step_size * grad
+                current_images = current_images + step_size * grad
+                current_images = current_images.clamp_(-1.0, 1.0)
 
-            latents = latents.detach()
+            current_images = current_images.detach()
 
             with torch.no_grad():
-                current_images = self._decode_packed_latents(latents, height, width)
                 current_scores = self.reward_model(current_images, reward_prompts)
                 score_trace.append(current_scores.mean().item())
 
             if record_intermediate:
                 intermediate_images.append(self.image_processor.postprocess(current_images, output_type="pil")[0])
 
-        if record_intermediate and intermediate_images:
-            final_image = intermediate_images[-1]
-        else:
-            final_image = self.image_processor.postprocess(current_images, output_type="pil")[0]
+        final_image = self.image_processor.postprocess(current_images, output_type="pil")[0]
 
         return {
             "final_image": final_image,
@@ -139,7 +135,7 @@ class AdversarialFluxPipeline(FluxPipeline):
         unpacked = self._unpack_latents(latents, height, width, self.vae_scale_factor)
         unpacked = (unpacked / self.vae.config.scaling_factor) + self.vae.config.shift_factor
         decoded = self.vae.decode(unpacked, return_dict=False)[0]
-        return decoded
+        return decoded.to(dtype=torch.float32)
 
     def _resolve_height_width(self, height: int, width: int) -> tuple[int, int]:
         height = height or self.default_sample_size * self.vae_scale_factor
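
Note (not part of the patch): what this change converges on is plain gradient
ascent on the decoded images against the reward model, rather than re-decoding
the packed latents on every round. Below is a minimal standalone sketch of that
technique, assuming a differentiable reward_fn that maps a batch of images in
[-1, 1] to per-sample scores; reward_fn and ascend_images are illustrative
names for this sketch, not diffusers API:

    import torch

    def ascend_images(images: torch.Tensor, reward_fn, num_rounds: int = 4, step_size: float = 0.05):
        # Record the mean reward before any ascent step, mirroring the patch's score_trace.
        score_trace = [reward_fn(images).mean().item()]
        for _ in range(num_rounds):
            images.requires_grad_(True)
            total_score = reward_fn(images).mean()
            # Differentiate the reward w.r.t. the images themselves, not the latents.
            grad = torch.autograd.grad(total_score, images)[0]
            with torch.no_grad():
                # Take an ascent step and keep pixels in the VAE's [-1, 1] output range,
                # as the patch now does with clamp_.
                images = (images + step_size * grad).clamp_(-1.0, 1.0)
            images = images.detach()
            with torch.no_grad():
                score_trace.append(reward_fn(images).mean().item())
        return images, score_trace

Working in image space also explains the float32 casts added above: the CLIP
reward and the autograd step both run on the decoded tensor, so it is promoted
once after decoding instead of inside the loop.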