1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Revert "Fix code style with make style."

This reverts commit 10a174a12c.
This commit is contained in:
Daniel Gu
2023-05-05 09:08:07 -07:00
parent fc8526354f
commit 9d39bef45f
2 changed files with 20 additions and 18 deletions

View File

@@ -1,5 +1,5 @@
import math
from typing import Optional, Union
from typing import Optional, Union, Tuple, List
import torch
from torch import nn
@@ -21,7 +21,7 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
logger.warning(
@@ -45,7 +45,7 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
@@ -53,13 +53,14 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
return tensor
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
# type: (torch.Tensor, float, float, float, float) -> torch.Tensor
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the normal distribution :math:`\mathcal{N}(\text{mean},
\text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for
generating the random values works best when :math:`a \leq \text{mean} \leq b`.
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
@@ -67,7 +68,8 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5) >>> nn.init.trunc_normal_(w)
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
@@ -100,7 +102,7 @@ class PatchEmbed(nn.Module):
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
else:
self.norm = None
self.use_pos_embed = use_pos_embed
if self.use_pos_embed:
pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
@@ -913,7 +915,7 @@ class UniDiffuserModel(ModelMixin, ConfigMixin):
self.num_tokens = 1 + 1 + num_text_tokens + 1 + self.num_patches
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, self.inner_dim))
self.pos_embed_drop = nn.Dropout(p=dropout)
trunc_normal_(self.pos_embed, std=0.02)
trunc_normal_(self.pos_embed, std=.02)
# 2. Define transformer blocks
self.transformer = UTransformer2DModel(
@@ -947,10 +949,10 @@ class UniDiffuserModel(ModelMixin, ConfigMixin):
self.vae_img_out = nn.Linear(self.inner_dim, patch_dim)
self.clip_img_out = nn.Linear(self.inner_dim, clip_img_dim)
self.text_out = nn.Linear(self.inner_dim, text_dim)
@torch.jit.ignore
def no_weight_decay(self):
return {"pos_embed"}
return {'pos_embed'}
def forward(
self,

View File

@@ -161,7 +161,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
# TODO: handle safety checking?
self.safety_checker = None
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
@@ -754,7 +754,9 @@ class UniDiffuserPipeline(DiffusionPipeline):
img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)
text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
_, _, text_out_uncond = self.unet(img_vae_T, img_clip_T, text_latents, t_img=max_timestep, t_text=t)
_, _, text_out_uncond = self.unet(
img_vae_T, img_clip_T, text_latents, t_img=max_timestep, t_text=t
)
img_vae_out_uncond, img_clip_out_uncond, _ = self.unet(
img_vae_latents, img_clip_latents, text_T, t_img=t, t_text=max_timestep
@@ -988,9 +990,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
)
full_latents_available = latents is not None
individual_latents_available = (
prompt_latents is not None or vae_latents is not None or clip_latents is not None
)
individual_latents_available = prompt_latents is not None or vae_latents is not None or clip_latents is not None
if full_latents_available and individual_latents_available:
logger.warning(
"You have supplied both `latents` and at least one of `prompt_latents`, `vae_latents`, and"