Mirror of https://github.com/huggingface/diffusers.git
commit 53476bfca9
parent 44126bd77e
Author: sayakpaul
Date:   2025-10-10 11:29:19 +05:30

@@ -27,8 +27,8 @@ class CLIPScore(nn.Module):
         self.eval()
 
     def forward(self, images: torch.Tensor, prompts: list[str]) -> torch.Tensor:
-        pixel_values = self._prepare_images(images)
+        device = next(self.model.parameters()).device
+        pixel_values = self._prepare_images(images).to(device=device, dtype=torch.float32)
         text_inputs = self.tokenizer(list(prompts), padding=True, truncation=True, return_tensors="pt").to(device)
         image_embeds = self.model.get_image_features(pixel_values=pixel_values)
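
The fix resolves the device from the model's own parameters and upcasts pixel values to float32 before encoding, so both the pixel and tokenized text tensors land where the CLIP weights actually live. A rough standalone sketch of the same pattern using the `transformers` CLIP API; the checkpoint and function name here are illustrative, not the repo's:

# Sketch of the device/dtype pattern from the hunk above, using the
# public `transformers` CLIP API. Checkpoint and names are illustrative.
import torch
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def clip_scores(images, prompts: list[str]) -> torch.Tensor:
    # Resolve the device from the model itself, robust to a later .to("cuda").
    device = next(model.parameters()).device
    inputs = processor(text=list(prompts), images=images, padding=True, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device=device, dtype=torch.float32)
    image_embeds = model.get_image_features(pixel_values=pixel_values)
    text_embeds = model.get_text_features(
        input_ids=inputs["input_ids"].to(device),
        attention_mask=inputs["attention_mask"].to(device),
    )
    image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
    text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
    return (image_embeds * text_embeds).sum(dim=-1)  # unscaled cosine similarity

Resolving the device per call keeps the score correct even when the module is moved with `.to(...)` after construction, which is what the old code appears to have assumed away.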
@@ -92,7 +92,7 @@ class AdversarialFluxPipeline(FluxPipeline):
         reward_prompts = self._expand_prompts(reward_prompt if reward_prompt is not None else prompt, latents.shape[0])
         with torch.no_grad():
-            current_images = self._decode_packed_latents(latents, height, width)
+            current_images = self._decode_packed_latents(latents, height, width).to(dtype=torch.float32)
             current_scores = self.reward_model(current_images, reward_prompts)
 
         intermediate_images: list[Image.Image] = []
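
This cast and the `return decoded.to(dtype=torch.float32)` added to `_decode_packed_latents` further down enforce the same boundary: decode in the VAE's native precision, score in fp32. A minimal sketch of that upcast-at-the-boundary pattern, assuming a bf16/fp16 VAE (common for FLUX); `vae`, `reward_model`, `latents`, and `prompts` stand in for pipeline state:

# Decode in the VAE's native (often bf16) precision, then upcast once at the
# boundary so reward scores and pixel gradients are computed in fp32.
with torch.no_grad():
    images = vae.decode(latents, return_dict=False)[0].to(dtype=torch.float32)
scores = reward_model(images, prompts)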
@@ -102,30 +102,26 @@ class AdversarialFluxPipeline(FluxPipeline):
         score_trace: list[float] = [current_scores.mean().item()]
         for _ in range(num_rounds):
-            latents.requires_grad_(True)
-            decoded = self._decode_packed_latents(latents, height, width)
-            scores = self.reward_model(decoded, reward_prompts)
+            current_images.requires_grad_(True)
+            scores = self.reward_model(current_images, reward_prompts)
             total_score = scores.mean()
-            grad = torch.autograd.grad(total_score, latents)[0]
+            grad = torch.autograd.grad(total_score, current_images)[0]
             with torch.no_grad():
-                latents = latents + step_size * grad
-            latents = latents.detach()
+                current_images = current_images + step_size * grad
+                current_images = current_images.clamp_(-1.0, 1.0)
+            current_images = current_images.detach()
             with torch.no_grad():
-                current_images = self._decode_packed_latents(latents, height, width)
                 current_scores = self.reward_model(current_images, reward_prompts)
             score_trace.append(current_scores.mean().item())
             if record_intermediate:
                 intermediate_images.append(self.image_processor.postprocess(current_images, output_type="pil")[0])
-        if record_intermediate and intermediate_images:
-            final_image = intermediate_images[-1]
-        else:
-            final_image = self.image_processor.postprocess(current_images, output_type="pil")[0]
+        final_image = self.image_processor.postprocess(current_images, output_type="pil")[0]
 
         return {
             "final_image": final_image,
@@ -139,7 +135,7 @@ class AdversarialFluxPipeline(FluxPipeline):
         unpacked = self._unpack_latents(latents, height, width, self.vae_scale_factor)
         unpacked = (unpacked / self.vae.config.scaling_factor) + self.vae.config.shift_factor
         decoded = self.vae.decode(unpacked, return_dict=False)[0]
-        return decoded
+        return decoded.to(dtype=torch.float32)
 
     def _resolve_height_width(self, height: int, width: int) -> tuple[int, int]:
         height = height or self.default_sample_size * self.vae_scale_factor