mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Revert "Fix code style with make style."
This reverts commit 10a174a12c.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import math
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union, Tuple, List
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -21,7 +21,7 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
||||
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
||||
def norm_cdf(x):
|
||||
# Computes standard normal cumulative distribution function
|
||||
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
||||
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
||||
|
||||
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
||||
logger.warning(
|
||||
@@ -45,7 +45,7 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
||||
tensor.erfinv_()
|
||||
|
||||
# Transform to proper mean, std
|
||||
tensor.mul_(std * math.sqrt(2.0))
|
||||
tensor.mul_(std * math.sqrt(2.))
|
||||
tensor.add_(mean)
|
||||
|
||||
# Clamp to ensure it's in the proper range
|
||||
@@ -53,13 +53,14 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
||||
return tensor
|
||||
|
||||
|
||||
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
||||
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
||||
# type: (torch.Tensor, float, float, float, float) -> torch.Tensor
|
||||
r"""Fills the input Tensor with values drawn from a truncated
|
||||
normal distribution. The values are effectively drawn from the normal distribution :math:`\mathcal{N}(\text{mean},
|
||||
\text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for
|
||||
generating the random values works best when :math:`a \leq \text{mean} \leq b`.
|
||||
|
||||
normal distribution. The values are effectively drawn from the
|
||||
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
||||
with values outside :math:`[a, b]` redrawn until they are within
|
||||
the bounds. The method used for generating the random values works
|
||||
best when :math:`a \leq \text{mean} \leq b`.
|
||||
Args:
|
||||
tensor: an n-dimensional `torch.Tensor`
|
||||
mean: the mean of the normal distribution
|
||||
@@ -67,7 +68,8 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
||||
a: the minimum cutoff value
|
||||
b: the maximum cutoff value
|
||||
Examples:
|
||||
>>> w = torch.empty(3, 5) >>> nn.init.trunc_normal_(w)
|
||||
>>> w = torch.empty(3, 5)
|
||||
>>> nn.init.trunc_normal_(w)
|
||||
"""
|
||||
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
||||
|
||||
@@ -100,7 +102,7 @@ class PatchEmbed(nn.Module):
|
||||
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
|
||||
self.use_pos_embed = use_pos_embed
|
||||
if self.use_pos_embed:
|
||||
pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
|
||||
@@ -913,7 +915,7 @@ class UniDiffuserModel(ModelMixin, ConfigMixin):
|
||||
self.num_tokens = 1 + 1 + num_text_tokens + 1 + self.num_patches
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, self.inner_dim))
|
||||
self.pos_embed_drop = nn.Dropout(p=dropout)
|
||||
trunc_normal_(self.pos_embed, std=0.02)
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
|
||||
# 2. Define transformer blocks
|
||||
self.transformer = UTransformer2DModel(
|
||||
@@ -947,10 +949,10 @@ class UniDiffuserModel(ModelMixin, ConfigMixin):
|
||||
self.vae_img_out = nn.Linear(self.inner_dim, patch_dim)
|
||||
self.clip_img_out = nn.Linear(self.inner_dim, clip_img_dim)
|
||||
self.text_out = nn.Linear(self.inner_dim, text_dim)
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
return {"pos_embed"}
|
||||
return {'pos_embed'}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -161,7 +161,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
|
||||
# TODO: handle safety checking?
|
||||
self.safety_checker = None
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
@@ -754,7 +754,9 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)
|
||||
text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
|
||||
|
||||
_, _, text_out_uncond = self.unet(img_vae_T, img_clip_T, text_latents, t_img=max_timestep, t_text=t)
|
||||
_, _, text_out_uncond = self.unet(
|
||||
img_vae_T, img_clip_T, text_latents, t_img=max_timestep, t_text=t
|
||||
)
|
||||
|
||||
img_vae_out_uncond, img_clip_out_uncond, _ = self.unet(
|
||||
img_vae_latents, img_clip_latents, text_T, t_img=t, t_text=max_timestep
|
||||
@@ -988,9 +990,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
)
|
||||
|
||||
full_latents_available = latents is not None
|
||||
individual_latents_available = (
|
||||
prompt_latents is not None or vae_latents is not None or clip_latents is not None
|
||||
)
|
||||
individual_latents_available = prompt_latents is not None or vae_latents is not None or clip_latents is not None
|
||||
if full_latents_available and individual_latents_available:
|
||||
logger.warning(
|
||||
"You have supplied both `latents` and at least one of `prompt_latents`, `vae_latents`, and"
|
||||
|
||||
Reference in New Issue
Block a user