diff --git a/examples/community/flux_adversarial_latents.py b/examples/community/flux_adversarial_latents.py
index 1f705fb260..acb6a71cc7 100644
--- a/examples/community/flux_adversarial_latents.py
+++ b/examples/community/flux_adversarial_latents.py
@@ -9,22 +9,26 @@
 from PIL import Image
 from transformers import CLIPImageProcessor, CLIPModel, CLIPTokenizer
 
 from diffusers import FluxPipeline
+from diffusers.utils import make_image_grid
 
 
 class CLIPScore(nn.Module):
-    def __init__(self, model_id: str = "openai/clip-vit-large-patch14", device: Optional[str] = None) -> None:
+    def __init__(self, model_id: str = "openai/clip-vit-large-patch14", device: str = None) -> None:
         super().__init__()
         self.model = CLIPModel.from_pretrained(model_id)
         self.tokenizer = CLIPTokenizer.from_pretrained(model_id)
         self.image_processor = CLIPImageProcessor.from_pretrained(model_id)
         if device is not None:
             self.model = self.model.to(device)
+        self.model = self.model.to(dtype=torch.float32)
+        self.model.eval()
+        for param in self.model.parameters():
+            param.requires_grad_(False)
         self.eval()
 
-    def forward(self, images: torch.Tensor, prompts: list[str]) -> torch.Tensor:  # type: ignore[override]
+    def forward(self, images: torch.Tensor, prompts: list[str]) -> torch.Tensor:
         pixel_values = self._prepare_images(images)
         device = next(self.model.parameters()).device
-        pixel_values = pixel_values.to(device=device, dtype=self.model.dtype)
         text_inputs = self.tokenizer(list(prompts), padding=True, truncation=True, return_tensors="pt").to(device)
         image_embeds = self.model.get_image_features(pixel_values=pixel_values)
@@ -54,14 +58,10 @@ class CLIPScore(nn.Module):
             1, -1, 1, 1
         )
         pixel_values = (pixel_values - mean) / std
-        return pixel_values.to(self.model.dtype)
+        return pixel_values.to(dtype=torch.float32)
 
 
 class AdversarialFluxPipeline(FluxPipeline):
-    def __init__(self, *args, reward_model: Optional[nn.Module] = None, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self.reward_model = reward_model
-
     def adversarial_refinement(
         self,
         prompt: Union[str, list[str]],
@@ -84,13 +84,11 @@ class AdversarialFluxPipeline(FluxPipeline):
         flux_output = super().__call__(prompt=prompt, **generate_kwargs)
         latents = flux_output.images
         device = latents.device
-        latents = latents.to(torch.float32)
 
-        if self.reward_model is None:
+        if getattr(self, "reward_model", None) is None:
             self.reward_model = CLIPScore(model_id=clip_model_id, device=device)
-        self.reward_model = self.reward_model.to(device)
 
         reward_prompts = self._expand_prompts(
             reward_prompt if reward_prompt is not None else prompt, latents.shape[0]
         )
 
@@ -98,9 +96,9 @@ class AdversarialFluxPipeline(FluxPipeline):
         current_images = self._decode_packed_latents(latents, height, width)
         current_scores = self.reward_model(current_images, reward_prompts)
 
-        intermediate_images: list[list[Image.Image]] = []
+        intermediate_images: list[Image.Image] = []
         if record_intermediate:
-            intermediate_images.append(self.image_processor.postprocess(current_images, output_type="pil"))
+            intermediate_images.append(self.image_processor.postprocess(current_images, output_type="pil")[0])
 
         score_trace: list[float] = [current_scores.mean().item()]
 
@@ -123,15 +121,15 @@ class AdversarialFluxPipeline(FluxPipeline):
 
             score_trace.append(current_scores.mean().item())
             if record_intermediate:
-                intermediate_images.append(self.image_processor.postprocess(current_images, output_type="pil"))
+                intermediate_images.append(self.image_processor.postprocess(current_images, output_type="pil")[0])
 
         if record_intermediate and intermediate_images:
-            final_images = intermediate_images[-1]
+            final_image = intermediate_images[-1]
         else:
-            final_images = self.image_processor.postprocess(current_images, output_type="pil")
+            final_image = self.image_processor.postprocess(current_images, output_type="pil")[0]
 
         return {
-            "images": final_images,
+            "final_image": final_image,
             "latents": latents.detach(),
             "score_trace": score_trace,
             "final_scores": current_scores.detach().cpu().tolist(),
@@ -141,13 +139,12 @@ class AdversarialFluxPipeline(FluxPipeline):
     def _decode_packed_latents(self, latents: torch.Tensor, height: int, width: int) -> torch.Tensor:
         unpacked = self._unpack_latents(latents, height, width, self.vae_scale_factor)
         unpacked = (unpacked / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-        return self.vae.decode(unpacked, return_dict=False)[0]
+        decoded = self.vae.decode(unpacked, return_dict=False)[0]
+        return decoded
 
-    def _resolve_height_width(self, height: Optional[int], width: Optional[int]) -> tuple[int, int]:
+    def _resolve_height_width(self, height: int, width: int) -> tuple[int, int]:
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
-        height = 2 * (int(height) // (self.vae_scale_factor * 2))
-        width = 2 * (int(width) // (self.vae_scale_factor * 2))
         return height, width
 
     @staticmethod
@@ -209,16 +206,15 @@ def main() -> None:
         record_intermediate=record_intermediate,
     )
 
-    image = result["images"][0]
-    image.save(args.output)
+    result["final_image"].save(args.output)
 
     if args.intermediate_dir and result["intermediate_images"]:
        output_dir = Path(args.intermediate_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
-        for round_idx, pil_images in enumerate(result["intermediate_images"]):
-            for sample_idx, pil_image in enumerate(pil_images):
-                filename = output_dir / f"round_{round_idx:02d}_sample_{sample_idx:02d}.png"
-                pil_image.save(filename)
+        images = result["intermediate_images"]
+        image_grid = make_image_grid(images, cols=len(images), rows=1)
+        filename = output_dir / "image_grid.png"
+        image_grid.save(filename)
 
     print("Average CLIP score trace:", result["score_trace"])
     print("Final per-sample CLIP scores:", result["final_scores"])
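
For context on the reward: the `CLIPScore` module above reduces to cosine similarity between frozen CLIP image and text embeddings. A minimal self-contained sketch of the same computation with the stock `transformers` API, using `CLIPProcessor` in place of the example's manual resize/normalize:

```python
# Sketch of the CLIPScore reward: cosine similarity between frozen CLIP
# image and text embeddings. CLIPProcessor stands in for the example's
# manual preprocessing; the eval/freeze steps mirror the added lines above.
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").eval()
for param in model.parameters():
    param.requires_grad_(False)  # frozen reward, as in the diff
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

image = Image.new("RGB", (512, 512))  # placeholder; use a decoded sample
inputs = processor(text=["a photo of a cat"], images=image, return_tensors="pt", padding=True)

image_embeds = model.get_image_features(pixel_values=inputs["pixel_values"])
text_embeds = model.get_text_features(
    input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
)
score = torch.nn.functional.cosine_similarity(image_embeds, text_embeds)
print(score.item())
```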
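The refinement loop itself falls between the hunks, but the surrounding code (no `torch.no_grad()` around `_decode_packed_latents`, a `score_trace` append per round) implies gradient ascent on the packed latents against the reward. A sketch of that pattern; `step_size`, the helper name, and the plain-ascent update rule are illustrative assumptions, not the example's actual loop:

```python
# Illustrative gradient-ascent step on latents against a differentiable
# reward; step_size and the update rule are assumptions, since the
# example's actual loop is outside the hunks shown.
import torch

def ascent_step(latents: torch.Tensor, reward_fn, step_size: float = 0.05) -> torch.Tensor:
    latents = latents.detach().requires_grad_(True)
    score = reward_fn(latents).mean()  # decode + CLIP scoring must stay differentiable
    score.backward()
    with torch.no_grad():
        latents = latents + step_size * latents.grad
    return latents.detach()
```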
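Usage after this change, given the new return schema (`final_image` is a single `PIL.Image`, `intermediate_images` a flat list with one image per round). The import path and any keyword arguments beyond those visible in the hunks are assumptions:

```python
# Sketch of driving the refactored example; the module path and the exact
# keyword set of adversarial_refinement are assumed from the hunks above.
import torch

from diffusers.utils import make_image_grid
from flux_adversarial_latents import AdversarialFluxPipeline  # examples/community/

pipe = AdversarialFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

result = pipe.adversarial_refinement(
    prompt="a photo of an astronaut riding a horse",
    reward_prompt=None,  # falls back to the generation prompt
    clip_model_id="openai/clip-vit-large-patch14",
    record_intermediate=True,
)

result["final_image"].save("refined.png")  # was result["images"][0]
print("trace:", result["score_trace"], "final:", result["final_scores"])

# One image per round now, so the rounds tile into a single grid row.
if result["intermediate_images"]:
    rounds = result["intermediate_images"]
    make_image_grid(rounds, rows=1, cols=len(rounds)).save("rounds.png")
```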