mirror of https://github.com/huggingface/diffusers.git
@@ -9,22 +9,26 @@ from PIL import Image
 from transformers import CLIPImageProcessor, CLIPModel, CLIPTokenizer

 from diffusers import FluxPipeline
+from diffusers.utils import make_image_grid


 class CLIPScore(nn.Module):
-    def __init__(self, model_id: str = "openai/clip-vit-large-patch14", device: Optional[str] = None) -> None:
+    def __init__(self, model_id: str = "openai/clip-vit-large-patch14", device: str = None) -> None:
         super().__init__()
         self.model = CLIPModel.from_pretrained(model_id)
         self.tokenizer = CLIPTokenizer.from_pretrained(model_id)
         self.image_processor = CLIPImageProcessor.from_pretrained(model_id)
         if device is not None:
             self.model = self.model.to(device)
+        self.model = self.model.to(dtype=torch.float32)
         self.model.eval()
         for param in self.model.parameters():
             param.requires_grad_(False)
+        self.eval()

-    def forward(self, images: torch.Tensor, prompts: list[str]) -> torch.Tensor:  # type: ignore[override]
+    def forward(self, images: torch.Tensor, prompts: list[str]) -> torch.Tensor:
         pixel_values = self._prepare_images(images)
+        device = next(self.model.parameters()).device
+        pixel_values = pixel_values.to(device=device, dtype=self.model.dtype)
         text_inputs = self.tokenizer(list(prompts), padding=True, truncation=True, return_tensors="pt").to(device)

         image_embeds = self.model.get_image_features(pixel_values=pixel_values)
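The hunk cuts off right after the image embeddings are computed. For orientation, a CLIP score is conventionally the scaled cosine similarity between image and text embeddings; below is a minimal sketch of how forward presumably continues. The variable names and the exact scaling are illustrative assumptions, not code from this commit, though get_text_features and logit_scale are standard CLIPModel API.

# Hedged sketch of the usual CLIP-score tail; not code from this commit.
text_embeds = self.model.get_text_features(**text_inputs)
# L2-normalize both embedding sets, then take the per-sample dot product.
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
scores = (image_embeds * text_embeds).sum(dim=-1) * self.model.logit_scale.exp()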
@@ -54,14 +58,10 @@ class CLIPScore(nn.Module):
             1, -1, 1, 1
         )
         pixel_values = (pixel_values - mean) / std
-        return pixel_values.to(self.model.dtype)
+        return pixel_values.to(dtype=torch.float32)


 class AdversarialFluxPipeline(FluxPipeline):
     def __init__(self, *args, reward_model: Optional[nn.Module] = None, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.reward_model = reward_model

     def adversarial_refinement(
         self,
         prompt: Union[str, list[str]],
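The "1, -1, 1, 1" fragment at the top of that hunk is the tail of a .view(1, -1, 1, 1) reshape on per-channel normalization tensors; the head of _prepare_images lies outside the hunk. A hedged sketch of the kind of code that would produce mean and std, assuming the processor's standard image_mean/image_std attributes:

# Hedged sketch, assumed rather than shown in the diff: per-channel CLIP
# normalization tensors, reshaped to broadcast over (batch, C, H, W).
mean = torch.tensor(self.image_processor.image_mean, device=pixel_values.device).view(
    1, -1, 1, 1
)
std = torch.tensor(self.image_processor.image_std, device=pixel_values.device).view(
    1, -1, 1, 1
)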
@@ -84,13 +84,10 @@ class AdversarialFluxPipeline(FluxPipeline):

         flux_output = super().__call__(prompt=prompt, **generate_kwargs)
         latents = flux_output.images

         device = latents.device
         latents = latents.to(torch.float32)

         if self.reward_model is None:
             self.reward_model = CLIPScore(model_id=clip_model_id, device=device)
         self.reward_model = self.reward_model.to(device)

         reward_prompts = self._expand_prompts(reward_prompt if reward_prompt is not None else prompt, latents.shape[0])
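For context, a hedged usage sketch of the lazy reward-model setup above. The model id, dtype, and the call signature of adversarial_refinement are assumptions inferred from the surrounding diff, not part of the commit:

# Hedged usage sketch (model id and kwargs are illustrative assumptions).
pipe = AdversarialFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
# Either let adversarial_refinement build a CLIPScore lazily, or inject one:
pipe.reward_model = CLIPScore(device="cuda")
result = pipe.adversarial_refinement(prompt="a corgi wearing sunglasses")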
@@ -98,9 +95,9 @@ class AdversarialFluxPipeline(FluxPipeline):
         current_images = self._decode_packed_latents(latents, height, width)
         current_scores = self.reward_model(current_images, reward_prompts)

-        intermediate_images: list[list[Image.Image]] = []
+        intermediate_images: list[Image.Image] = []
         if record_intermediate:
-            intermediate_images.append(self.image_processor.postprocess(current_images, output_type="pil"))
+            intermediate_images.append(self.image_processor.postprocess(current_images, output_type="pil")[0])

         score_trace: list[float] = [current_scores.mean().item()]
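The refinement loop body itself (old lines 107-122) falls between this hunk and the next, so it is not visible here. A hedged sketch of the kind of reward-ascent step such a loop typically performs on the packed latents; num_rounds and step_size are hypothetical names, not taken from the commit:

# Hedged sketch of one reward-ascent round; not the commit's actual loop.
for _ in range(num_rounds):
    latents = latents.detach().requires_grad_(True)
    current_images = self._decode_packed_latents(latents, height, width)
    current_scores = self.reward_model(current_images, reward_prompts)
    # Ascend the mean reward by stepping along the latent gradient.
    current_scores.mean().backward()
    with torch.no_grad():
        latents = latents + step_size * latents.grad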
@@ -123,15 +120,15 @@ class AdversarialFluxPipeline(FluxPipeline):

             score_trace.append(current_scores.mean().item())
             if record_intermediate:
-                intermediate_images.append(self.image_processor.postprocess(current_images, output_type="pil"))
+                intermediate_images.append(self.image_processor.postprocess(current_images, output_type="pil")[0])

         if record_intermediate and intermediate_images:
-            final_images = intermediate_images[-1]
+            final_image = intermediate_images[-1]
         else:
-            final_images = self.image_processor.postprocess(current_images, output_type="pil")
+            final_image = self.image_processor.postprocess(current_images, output_type="pil")[0]

         return {
-            "images": final_images,
+            "final_image": final_image,
             "latents": latents.detach(),
             "score_trace": score_trace,
             "final_scores": current_scores.detach().cpu().tolist(),
@@ -141,13 +138,12 @@ class AdversarialFluxPipeline(FluxPipeline):
     def _decode_packed_latents(self, latents: torch.Tensor, height: int, width: int) -> torch.Tensor:
         unpacked = self._unpack_latents(latents, height, width, self.vae_scale_factor)
         unpacked = (unpacked / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-        return self.vae.decode(unpacked, return_dict=False)[0]
+        decoded = self.vae.decode(unpacked, return_dict=False)[0]
+        return decoded

-    def _resolve_height_width(self, height: Optional[int], width: Optional[int]) -> tuple[int, int]:
+    def _resolve_height_width(self, height: int, width: int) -> tuple[int, int]:
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
         height = 2 * (int(height) // (self.vae_scale_factor * 2))
         width = 2 * (int(width) // (self.vae_scale_factor * 2))
         return height, width

     @staticmethod
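A quick worked example of the rounding in _resolve_height_width: the formula divides the requested size by vae_scale_factor and floors the result to an even integer. The value of vae_scale_factor is not shown in this diff; 8 is assumed here as the usual Flux setting.

# Worked example of the formula above; vae_scale_factor = 8 is an assumption.
vae_scale_factor = 8
for requested in (1024, 1000, 512):
    resolved = 2 * (int(requested) // (vae_scale_factor * 2))
    print(requested, "->", resolved)
# 1024 -> 128, 1000 -> 124, 512 -> 64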
@@ -209,16 +205,15 @@ def main() -> None:
         record_intermediate=record_intermediate,
     )

-    image = result["images"][0]
-    image.save(args.output)
+    result["final_image"].save(args.output)

     if args.intermediate_dir and result["intermediate_images"]:
         output_dir = Path(args.intermediate_dir)
         output_dir.mkdir(parents=True, exist_ok=True)
-        for round_idx, pil_images in enumerate(result["intermediate_images"]):
-            for sample_idx, pil_image in enumerate(pil_images):
-                filename = output_dir / f"round_{round_idx:02d}_sample_{sample_idx:02d}.png"
-                pil_image.save(filename)
+        images = result["intermediate_images"]
+        image_grid = make_image_grid(images, cols=len(images), rows=1)
+        filename = output_dir / "image_grid.png"
+        image_grid.save(filename)

     print("Average CLIP score trace:", result["score_trace"])
     print("Final per-sample CLIP scores:", result["final_scores"])
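For reference, a small self-contained sketch of what the new make_image_grid branch produces: diffusers' make_image_grid lays the per-round intermediates out in one row, and it requires rows * cols == len(images). The solid-color frames below are stand-ins for the recorded intermediates.

# Minimal runnable illustration of the single-row grid the new code saves.
from PIL import Image
from diffusers.utils import make_image_grid

frames = [Image.new("RGB", (64, 64), color) for color in ("red", "green", "blue")]
grid = make_image_grid(frames, rows=1, cols=len(frames))  # one cell per round
grid.save("image_grid.png")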