@@ -27,8 +27,8 @@ class CLIPScore(nn.Module):
         self.eval()
 
     def forward(self, images: torch.Tensor, prompts: list[str]) -> torch.Tensor:
-        pixel_values = self._prepare_images(images)
+        device = next(self.model.parameters()).device
+        pixel_values = self._prepare_images(images).to(device=device, dtype=torch.float32)
         text_inputs = self.tokenizer(list(prompts), padding=True, truncation=True, return_tensors="pt").to(device)
 
         image_embeds = self.model.get_image_features(pixel_values=pixel_values)
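The hunk above resolves the reward model's device from its own parameters and promotes pixels to float32 before encoding. A minimal standalone sketch of the same pattern using the transformers CLIP classes directly; clip_score is a hypothetical wrapper, and it assumes images are already preprocessed to CLIP's expected resolution and normalization (the class's _prepare_images helper covers that in the diff):

import torch
import torch.nn.functional as F
from transformers import CLIPModel, CLIPTokenizer

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

def clip_score(images: torch.Tensor, prompts: list[str]) -> torch.Tensor:
    # Resolve device/dtype from the model itself, as the diff does, so the
    # caller never passes a device explicitly. No torch.no_grad() here:
    # the adversarial loop differentiates through this score.
    device = next(model.parameters()).device
    pixel_values = images.to(device=device, dtype=torch.float32)
    text_inputs = tokenizer(list(prompts), padding=True, truncation=True, return_tensors="pt").to(device)

    image_embeds = F.normalize(model.get_image_features(pixel_values=pixel_values), dim=-1)
    text_embeds = F.normalize(model.get_text_features(**text_inputs), dim=-1)
    # Cosine similarity of matching image/prompt pairs; assumes len(prompts) == batch size.
    return (image_embeds * text_embeds).sum(dim=-1)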
@@ -92,7 +92,7 @@ class AdversarialFluxPipeline(FluxPipeline):
         reward_prompts = self._expand_prompts(reward_prompt if reward_prompt is not None else prompt, latents.shape[0])
 
         with torch.no_grad():
-            current_images = self._decode_packed_latents(latents, height, width)
+            current_images = self._decode_packed_latents(latents, height, width).to(dtype=torch.float32)
             current_scores = self.reward_model(current_images, reward_prompts)
 
         intermediate_images: list[Image.Image] = []
@@ -102,30 +102,26 @@ class AdversarialFluxPipeline(FluxPipeline):
         score_trace: list[float] = [current_scores.mean().item()]
 
         for _ in range(num_rounds):
-            latents.requires_grad_(True)
-            decoded = self._decode_packed_latents(latents, height, width)
-            scores = self.reward_model(decoded, reward_prompts)
+            current_images.requires_grad_(True)
+            scores = self.reward_model(current_images, reward_prompts)
             total_score = scores.mean()
 
-            grad = torch.autograd.grad(total_score, latents)[0]
+            grad = torch.autograd.grad(total_score, current_images)[0]
 
             with torch.no_grad():
-                latents = latents + step_size * grad
+                current_images = current_images + step_size * grad
+                current_images = current_images.clamp_(-1.0, 1.0)
 
-            latents = latents.detach()
+            current_images = current_images.detach()
 
             with torch.no_grad():
-                current_images = self._decode_packed_latents(latents, height, width)
                 current_scores = self.reward_model(current_images, reward_prompts)
 
             score_trace.append(current_scores.mean().item())
             if record_intermediate:
                 intermediate_images.append(self.image_processor.postprocess(current_images, output_type="pil")[0])
 
-        if record_intermediate and intermediate_images:
-            final_image = intermediate_images[-1]
-        else:
-            final_image = self.image_processor.postprocess(current_images, output_type="pil")[0]
+        final_image = self.image_processor.postprocess(current_images, output_type="pil")[0]
 
         return {
             "final_image": final_image,
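These two hunks move the ascent from packed-latent space into image space: the loop now decodes once up front, differentiates the reward with respect to the decoded image tensor, steps along the gradient, and clamps to the VAE's [-1, 1] output range instead of re-decoding every round. A self-contained sketch of that loop; the function name, step size, and toy reward are illustrative stand-ins for the pipeline's reward model:

import torch

def reward_ascent(images: torch.Tensor, reward_fn, num_rounds: int = 10,
                  step_size: float = 0.05) -> tuple[torch.Tensor, list[float]]:
    # images: decoded pixels in [-1, 1], shape (B, C, H, W); reward_fn maps
    # images -> per-sample scores of shape (B,) and must be differentiable.
    images = images.detach().to(torch.float32)
    with torch.no_grad():
        trace = [reward_fn(images).mean().item()]

    for _ in range(num_rounds):
        images.requires_grad_(True)
        total_score = reward_fn(images).mean()
        # Gradient w.r.t. the pixels themselves, as in the new loop body.
        grad = torch.autograd.grad(total_score, images)[0]

        with torch.no_grad():
            images = images + step_size * grad
            images.clamp_(-1.0, 1.0)  # stay in the VAE's output range

        images = images.detach()
        with torch.no_grad():
            trace.append(reward_fn(images).mean().item())

    return images, trace

# Toy check: the reward pulls pixels toward gray, so the trace should rise.
x = torch.rand(2, 3, 64, 64) * 2 - 1
final, trace = reward_ascent(x, lambda im: -(im ** 2).flatten(1).mean(dim=1))
print(f"{trace[0]:.4f} -> {trace[-1]:.4f}")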
@@ -139,7 +135,7 @@ class AdversarialFluxPipeline(FluxPipeline):
         unpacked = self._unpack_latents(latents, height, width, self.vae_scale_factor)
         unpacked = (unpacked / self.vae.config.scaling_factor) + self.vae.config.shift_factor
         decoded = self.vae.decode(unpacked, return_dict=False)[0]
-        return decoded
+        return decoded.to(dtype=torch.float32)
 
     def _resolve_height_width(self, height: int, width: int) -> tuple[int, int]:
         height = height or self.default_sample_size * self.vae_scale_factor
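The last hunk keeps the decode helper's output in float32 so the reward model never sees half-precision pixels. For reference, a minimal sketch of the same unscale-then-decode step with a diffusers AutoencoderKL; the checkpoint id is illustrative and the packed-latent unpacking is omitted:

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae").eval()

@torch.no_grad()
def decode_latents(latents: torch.Tensor) -> torch.Tensor:
    # Undo the scaling applied at encode time before decoding,
    # mirroring the context lines in the hunk above.
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
    return vae.decode(latents, return_dict=False)[0].to(dtype=torch.float32)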