commit 44126bd77e (parent 8997e88d85)
Author: sayakpaul
Date:   2025-10-09 10:23:28 +05:30


@@ -9,22 +9,26 @@ from PIL import Image
from transformers import CLIPImageProcessor, CLIPModel, CLIPTokenizer
from diffusers import FluxPipeline
from diffusers.utils import make_image_grid

class CLIPScore(nn.Module):
    def __init__(self, model_id: str = "openai/clip-vit-large-patch14", device: Optional[str] = None) -> None:
    def __init__(self, model_id: str = "openai/clip-vit-large-patch14", device: str = None) -> None:
        super().__init__()
        self.model = CLIPModel.from_pretrained(model_id)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_id)
        self.image_processor = CLIPImageProcessor.from_pretrained(model_id)
        if device is not None:
            self.model = self.model.to(device)
        self.model = self.model.to(dtype=torch.float32)
        self.model.eval()
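        # Freeze the CLIP reward model; gradients should flow only into the generated images.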
        for param in self.model.parameters():
            param.requires_grad_(False)
        self.eval()

    def forward(self, images: torch.Tensor, prompts: list[str]) -> torch.Tensor: # type: ignore[override]
    def forward(self, images: torch.Tensor, prompts: list[str]) -> torch.Tensor:
        pixel_values = self._prepare_images(images)
        device = next(self.model.parameters()).device
        pixel_values = pixel_values.to(device=device, dtype=self.model.dtype)
        text_inputs = self.tokenizer(list(prompts), padding=True, truncation=True, return_tensors="pt").to(device)
        image_embeds = self.model.get_image_features(pixel_values=pixel_values)
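
The rest of forward() falls outside this hunk; presumably it embeds the prompts and scores each image against its prompt by cosine similarity, along the lines of this standalone sketch (the helper name and code below are illustrative, not from the commit):

import torch

def clip_cosine_scores(image_embeds: torch.Tensor, text_embeds: torch.Tensor) -> torch.Tensor:
    # L2-normalize both embedding sets, then take the per-sample cosine similarity.
    image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
    text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
    return (image_embeds * text_embeds).sum(dim=-1)
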
@@ -54,14 +58,10 @@ class CLIPScore(nn.Module):
            1, -1, 1, 1
        )
        pixel_values = (pixel_values - mean) / std
        return pixel_values.to(self.model.dtype)
        return pixel_values.to(dtype=torch.float32)


class AdversarialFluxPipeline(FluxPipeline):
    def __init__(self, *args, reward_model: Optional[nn.Module] = None, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.reward_model = reward_model

    def adversarial_refinement(
        self,
        prompt: Union[str, list[str]],
@@ -84,13 +84,10 @@ class AdversarialFluxPipeline(FluxPipeline):
        flux_output = super().__call__(prompt=prompt, **generate_kwargs)
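        # flux_output.images holds packed Flux latents here (the base call is presumably made with output_type="latent").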
        latents = flux_output.images
        device = latents.device
        latents = latents.to(torch.float32)
        if self.reward_model is None:
            self.reward_model = CLIPScore(model_id=clip_model_id, device=device)
        self.reward_model = self.reward_model.to(device)
        reward_prompts = self._expand_prompts(reward_prompt if reward_prompt is not None else prompt, latents.shape[0])
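
The _expand_prompts helper does not appear in this diff; a plausible sketch of what such a helper might do (broadcasting a single prompt to the batch size), written with an illustrative name:

def expand_prompts(prompt, batch_size: int) -> list[str]:
    # Hypothetical helper: broadcast one prompt to the whole batch, keep per-sample lists as-is.
    prompts = [prompt] if isinstance(prompt, str) else list(prompt)
    if len(prompts) == 1 and batch_size > 1:
        prompts = prompts * batch_size
    return prompts
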
@@ -98,9 +95,9 @@ class AdversarialFluxPipeline(FluxPipeline):
        current_images = self._decode_packed_latents(latents, height, width)
        current_scores = self.reward_model(current_images, reward_prompts)
        intermediate_images: list[list[Image.Image]] = []
        intermediate_images: list[Image.Image] = []
        if record_intermediate:
            intermediate_images.append(self.image_processor.postprocess(current_images, output_type="pil"))
            intermediate_images.append(self.image_processor.postprocess(current_images, output_type="pil")[0])
        score_trace: list[float] = [current_scores.mean().item()]
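
The refinement loop itself sits between this hunk and the next. In generic form, gradient ascent on latents against a differentiable reward looks like the sketch below; the function and argument names (refine_latents, num_rounds, step_size) are illustrative, not taken from the commit:

import torch

def refine_latents(latents, decode_fn, reward_fn, num_rounds: int = 4, step_size: float = 0.05):
    # Repeatedly decode, score, and nudge the latents in the direction that raises the reward.
    for _ in range(num_rounds):
        latents = latents.detach().requires_grad_(True)
        scores = reward_fn(decode_fn(latents))
        (-scores.mean()).backward()
        with torch.no_grad():
            latents = latents + step_size * latents.grad
    return latents.detach()
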
@@ -123,15 +120,15 @@ class AdversarialFluxPipeline(FluxPipeline):
            score_trace.append(current_scores.mean().item())
            if record_intermediate:
                intermediate_images.append(self.image_processor.postprocess(current_images, output_type="pil"))
                intermediate_images.append(self.image_processor.postprocess(current_images, output_type="pil")[0])

        if record_intermediate and intermediate_images:
            final_images = intermediate_images[-1]
            final_image = intermediate_images[-1]
        else:
            final_images = self.image_processor.postprocess(current_images, output_type="pil")
            final_image = self.image_processor.postprocess(current_images, output_type="pil")[0]

        return {
            "images": final_images,
            "final_image": final_image,
            "latents": latents.detach(),
            "score_trace": score_trace,
            "final_scores": current_scores.detach().cpu().tolist(),
@@ -141,13 +138,12 @@ class AdversarialFluxPipeline(FluxPipeline):
    def _decode_packed_latents(self, latents: torch.Tensor, height: int, width: int) -> torch.Tensor:
        unpacked = self._unpack_latents(latents, height, width, self.vae_scale_factor)
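        # Undo the VAE scaling and shift applied at encode time before decoding back to pixel space.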
        unpacked = (unpacked / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        return self.vae.decode(unpacked, return_dict=False)[0]
        decoded = self.vae.decode(unpacked, return_dict=False)[0]
        return decoded

    def _resolve_height_width(self, height: Optional[int], width: Optional[int]) -> tuple[int, int]:
    def _resolve_height_width(self, height: int, width: int) -> tuple[int, int]:
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))
        return height, width

    @staticmethod
@@ -209,16 +205,15 @@ def main() -> None:
        record_intermediate=record_intermediate,
    )
    image = result["images"][0]
    image.save(args.output)
    result["final_image"].save(args.output)

    if args.intermediate_dir and result["intermediate_images"]:
        output_dir = Path(args.intermediate_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        for round_idx, pil_images in enumerate(result["intermediate_images"]):
            for sample_idx, pil_image in enumerate(pil_images):
                filename = output_dir / f"round_{round_idx:02d}_sample_{sample_idx:02d}.png"
                pil_image.save(filename)
        images = result["intermediate_images"]
        image_grid = make_image_grid(images, cols=len(images), rows=1)
        filename = output_dir / "image_grid.png"
        image_grid.save(filename)

    print("Average CLIP score trace:", result["score_trace"])
    print("Final per-sample CLIP scores:", result["final_scores"])