1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Big refactor] move unets to unets module 🦋 (#6630)

* move unets to  module 🦋

* parameterize unet-level import.

* fix flax unet2dcondition model import

* models __init__

* mildly depcrecating models.unet_2d_blocks in favor of models.unets.unet_2d_blocks.

* noqa

* correct depcrecation behaviour

* inherit from the actual classes.

* Empty-Commit

* backwards compatibility for unet_2d.py

* backward compatibility for unet_2d_condition

* bc for unet_1d

* bc for unet_1d_blocks
This commit is contained in:
Sayak Paul
2024-01-23 08:57:58 +05:30
committed by GitHub
parent 5e96333cb2
commit 1f0705adcf
51 changed files with 6619 additions and 6106 deletions

View File

@@ -22,4 +22,4 @@ The abstract from the paper is:
[[autodoc]] UNetMotionModel
## UNet3DConditionOutput
[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput
[[autodoc]] models.unets.unet_3d_condition.UNet3DConditionOutput

View File

@@ -22,4 +22,4 @@ The abstract from the paper is:
[[autodoc]] UNet1DModel
## UNet1DOutput
[[autodoc]] models.unet_1d.UNet1DOutput
[[autodoc]] models.unets.unet_1d.UNet1DOutput

View File

@@ -22,10 +22,10 @@ The abstract from the paper is:
[[autodoc]] UNet2DConditionModel
## UNet2DConditionOutput
[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput
[[autodoc]] models.unets.unet_2d_condition.UNet2DConditionOutput
## FlaxUNet2DConditionModel
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionModel
[[autodoc]] models.unets.unet_2d_condition_flax.FlaxUNet2DConditionModel
## FlaxUNet2DConditionOutput
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput
[[autodoc]] models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput

View File

@@ -22,4 +22,4 @@ The abstract from the paper is:
[[autodoc]] UNet2DModel
## UNet2DOutput
[[autodoc]] models.unet_2d.UNet2DOutput
[[autodoc]] models.unets.unet_2d.UNet2DOutput

View File

@@ -22,4 +22,4 @@ The abstract from the paper is:
[[autodoc]] UNet3DConditionModel
## UNet3DConditionOutput
[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput
[[autodoc]] models.unets.unet_3d_condition.UNet3DConditionOutput

View File

@@ -26,7 +26,7 @@ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel, UNetMotionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.models.unet_motion_model import MotionAdapter
from diffusers.models.unets.unet_motion_model import MotionAdapter
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import (

View File

@@ -8,7 +8,7 @@ import torch
from diffusers import StableDiffusionControlNetPipeline
from diffusers.models import ControlNetModel
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import logging

View File

@@ -7,7 +7,7 @@ import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
from diffusers.utils import PIL_INTERPOLATION, logging

View File

@@ -8,7 +8,7 @@ import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.unet_2d_blocks import (
from diffusers.models.unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
CrossAttnUpBlock2D,
DownBlock2D,

View File

@@ -26,7 +26,7 @@ from diffusers.models.attention_processor import USE_PEFT_BACKEND, AttentionProc
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.models.lora import LoRACompatibleConv
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import (
from diffusers.models.unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
CrossAttnUpBlock2D,
DownBlock2D,
@@ -36,7 +36,7 @@ from diffusers.models.unet_2d_blocks import (
UpBlock2D,
Upsample2D,
)
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.utils import BaseOutput, logging

View File

@@ -10,7 +10,7 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import VQModel
from diffusers.models.attention_processor import AttnProcessor
from diffusers.models.uvit_2d import UVit2DModel
from diffusers.models.unets.uvit_2d import UVit2DModel
from diffusers.pipelines.amused.pipeline_amused import AmusedPipeline
from diffusers.schedulers import AmusedScheduler

View File

@@ -14,7 +14,7 @@ from tqdm import tqdm
from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel
from diffusers.models.autoencoders.vae import Encoder
from diffusers.models.embeddings import TimestepEmbedding
from diffusers.models.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
from diffusers.models.unets.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
args = ArgumentParser()

View File

@@ -382,7 +382,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
_import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
_import_structure["schedulers"].extend(
@@ -711,7 +711,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .models.controlnet_flax import FlaxControlNetModel
from .models.modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.vae_flax import FlaxAutoencoderKL
from .pipelines import FlaxDiffusionPipeline
from .schedulers import (

View File

@@ -16,7 +16,7 @@ import numpy as np
import torch
import tqdm
from ...models.unet_1d import UNet1DModel
from ...models.unets.unet_1d import UNet1DModel
from ...pipelines import DiffusionPipeline
from ...utils.dummy_pt_objects import DDPMScheduler
from ...utils.torch_utils import randn_tensor

View File

@@ -39,19 +39,19 @@ if is_torch_available():
_import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["unet_1d"] = ["UNet1DModel"]
_import_structure["unet_2d"] = ["UNet2DModel"]
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["uvit_2d"] = ["UVit2DModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unets.unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"]
_import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["unets.uvit_2d"] = ["UVit2DModel"]
_import_structure["vq_model"] = ["VQModel"]
if is_flax_available():
_import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
@@ -73,19 +73,22 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
from .unet_1d import UNet1DModel
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_condition import UNet3DConditionModel
from .unet_kandinsky3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .uvit_2d import UVit2DModel
from .unets import (
Kandinsky3UNet,
MotionAdapter,
UNet1DModel,
UNet2DConditionModel,
UNet2DModel,
UNet3DConditionModel,
UNetMotionModel,
UNetSpatioTemporalConditionModel,
UVit2DModel,
)
from .vq_model import VQModel
if is_flax_available():
from .controlnet_flax import FlaxControlNetModel
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
from .unets import FlaxUNet2DConditionModel
from .vae_flax import FlaxAutoencoderKL
else:

View File

@@ -157,7 +157,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
self.use_slicing = False
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
@@ -181,7 +181,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -216,7 +216,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
@@ -448,7 +448,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
return DecoderOutput(sample=dec)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
@@ -472,7 +472,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.

View File

@@ -23,7 +23,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
@@ -242,7 +242,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin
module.gradient_checkpointing = value
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
@@ -266,7 +266,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.

View File

@@ -31,7 +31,7 @@ from ..attention_processor import (
AttnProcessor,
)
from ..modeling_utils import ModelMixin
from ..unet_2d import UNet2DModel
from ..unets.unet_2d import UNet2DModel
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
@@ -187,7 +187,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
self.use_slicing = False
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
@@ -211,7 +211,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -246,7 +246,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.

View File

@@ -22,7 +22,7 @@ from ...utils import BaseOutput, is_torch_version
from ...utils.torch_utils import randn_tensor
from ..activations import get_activation
from ..attention_processor import SpatialNorm
from ..unet_2d_blocks import (
from ..unets.unet_2d_blocks import (
AutoencoderTinyBlock,
UNetMidBlock2D,
get_down_block,

View File

@@ -30,8 +30,14 @@ from .attention_processor import (
)
from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_2d_blocks import CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2D, UNetMidBlock2DCrossAttn, get_down_block
from .unet_2d_condition import UNet2DConditionModel
from .unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
DownBlock2D,
UNetMidBlock2D,
UNetMidBlock2DCrossAttn,
get_down_block,
)
from .unets.unet_2d_condition import UNet2DConditionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -509,7 +515,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
return controlnet
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
@@ -533,7 +539,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -568,7 +574,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
@@ -584,7 +590,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
self.set_attn_processor(processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
r"""
Enable sliced attention computation.

View File

@@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..utils import BaseOutput
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from .modeling_flax_utils import FlaxModelMixin
from .unet_2d_blocks_flax import (
from .unets.unet_2d_blocks_flax import (
FlaxCrossAttnDownBlock2D,
FlaxDownBlock2D,
FlaxUNetMidBlock2DCrossAttn,
@@ -329,14 +329,14 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
conditioning_scale (`float`, *optional*, defaults to `1.0`): the scale factor for controlnet outputs
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
plain tuple.
train (`bool`, *optional*, defaults to `False`):
Use deterministic functions and disable dropout when not training.
Returns:
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is the sample tensor.
"""
channel_order = self.controlnet_conditioning_channel_order

View File

@@ -120,7 +120,7 @@ class DualTransformer2DModel(nn.Module):
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
Returns:
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:

View File

@@ -167,7 +167,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
@@ -191,7 +191,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -226,7 +226,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.

View File

@@ -286,7 +286,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
above. This bias will be added to the cross-attention scores.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:

View File

@@ -149,7 +149,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:

View File

@@ -12,244 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
from ..utils import deprecate
from .unets.unet_1d import UNet1DModel, UNet1DOutput
@dataclass
class UNet1DOutput(BaseOutput):
"""
The output of [`UNet1DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
The hidden states output from the last layer of the model.
"""
sample: torch.FloatTensor
class UNet1DOutput(UNet1DOutput):
deprecation_message = "Importing `UNet1DOutput` from `diffusers.models.unet_1d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d import UNet1DOutput`, instead."
deprecate("UNet1DOutput", "0.29", deprecation_message)
class UNet1DModel(ModelMixin, ConfigMixin):
r"""
A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
extra_in_channels (`int`, *optional*, defaults to 0):
Number of additional channels to be added to the input of the first down block. Useful for cases where the
input data has more channels than what the model was initially designed for.
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip sin to cos for Fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
Tuple of block output channels.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
downsample_each_block (`int`, *optional*, defaults to `False`):
Experimental feature for using a UNet without upsampling.
"""
@register_to_config
def __init__(
self,
sample_size: int = 65536,
sample_rate: Optional[int] = None,
in_channels: int = 2,
out_channels: int = 2,
extra_in_channels: int = 0,
time_embedding_type: str = "fourier",
flip_sin_to_cos: bool = True,
use_timestep_embedding: bool = False,
freq_shift: float = 0.0,
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
mid_block_type: Tuple[str] = "UNetMidBlock1D",
out_block_type: str = None,
block_out_channels: Tuple[int] = (32, 32, 64),
act_fn: str = None,
norm_num_groups: int = 8,
layers_per_block: int = 1,
downsample_each_block: bool = False,
):
super().__init__()
self.sample_size = sample_size
# time
if time_embedding_type == "fourier":
self.time_proj = GaussianFourierProjection(
embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
)
timestep_input_dim = 2 * block_out_channels[0]
elif time_embedding_type == "positional":
self.time_proj = Timesteps(
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
)
timestep_input_dim = block_out_channels[0]
if use_timestep_embedding:
time_embed_dim = block_out_channels[0] * 4
self.time_mlp = TimestepEmbedding(
in_channels=timestep_input_dim,
time_embed_dim=time_embed_dim,
act_fn=act_fn,
out_dim=block_out_channels[0],
)
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
self.out_block = None
# down
output_channel = in_channels
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
if i == 0:
input_channel += extra_in_channels
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=block_out_channels[0],
add_downsample=not is_final_block or downsample_each_block,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = get_mid_block(
mid_block_type,
in_channels=block_out_channels[-1],
mid_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
embed_dim=block_out_channels[0],
num_layers=layers_per_block,
add_downsample=downsample_each_block,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
if out_block_type is None:
final_upsample_channels = out_channels
else:
final_upsample_channels = block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = (
reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
)
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block,
in_channels=prev_output_channel,
out_channels=output_channel,
temb_channels=block_out_channels[0],
add_upsample=not is_final_block,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
self.out_block = get_out_block(
out_block_type=out_block_type,
num_groups_out=num_groups_out,
embed_dim=block_out_channels[0],
out_channels=out_channels,
act_fn=act_fn,
fc_dim=block_out_channels[-1] // 4,
)
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
return_dict: bool = True,
) -> Union[UNet1DOutput, Tuple]:
r"""
The [`UNet1DModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
Returns:
[`~models.unet_1d.UNet1DOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
returned where the first element is the sample tensor.
"""
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
timestep_embed = self.time_proj(timesteps)
if self.config.use_timestep_embedding:
timestep_embed = self.time_mlp(timestep_embed)
else:
timestep_embed = timestep_embed[..., None]
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
# 2. down
down_block_res_samples = ()
for downsample_block in self.down_blocks:
sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
down_block_res_samples += res_samples
# 3. mid
if self.mid_block:
sample = self.mid_block(sample, timestep_embed)
# 4. up
for i, upsample_block in enumerate(self.up_blocks):
res_samples = down_block_res_samples[-1:]
down_block_res_samples = down_block_res_samples[:-1]
sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
# 5. post-process
if self.out_block:
sample = self.out_block(sample, timestep_embed)
if not return_dict:
return (sample,)
return UNet1DOutput(sample=sample)
class UNet1DModel(UNet1DModel):
deprecation_message = "Importing `UNet1DModel` from `diffusers.models.unet_1d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d import UNet1DModel`, instead."
deprecate("UNet1DModel", "0.29", deprecation_message)

View File

@@ -11,616 +11,112 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from .activations import get_activation
from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
class DownResnetBlock1D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
num_layers: int = 1,
conv_shortcut: bool = False,
temb_channels: int = 32,
groups: int = 32,
groups_out: Optional[int] = None,
non_linearity: Optional[str] = None,
time_embedding_norm: str = "default",
output_scale_factor: float = 1.0,
add_downsample: bool = True,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.add_downsample = add_downsample
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
# there will always be at least one resnet
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
for _ in range(num_layers):
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
self.resnets = nn.ModuleList(resnets)
if non_linearity is None:
self.nonlinearity = None
else:
self.nonlinearity = get_activation(non_linearity)
self.downsample = None
if add_downsample:
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
output_states = ()
hidden_states = self.resnets[0](hidden_states, temb)
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
output_states += (hidden_states,)
if self.nonlinearity is not None:
hidden_states = self.nonlinearity(hidden_states)
if self.downsample is not None:
hidden_states = self.downsample(hidden_states)
return hidden_states, output_states
class UpResnetBlock1D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
num_layers: int = 1,
temb_channels: int = 32,
groups: int = 32,
groups_out: Optional[int] = None,
non_linearity: Optional[str] = None,
time_embedding_norm: str = "default",
output_scale_factor: float = 1.0,
add_upsample: bool = True,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.time_embedding_norm = time_embedding_norm
self.add_upsample = add_upsample
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
# there will always be at least one resnet
resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
for _ in range(num_layers):
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
self.resnets = nn.ModuleList(resnets)
if non_linearity is None:
self.nonlinearity = None
else:
self.nonlinearity = get_activation(non_linearity)
self.upsample = None
if add_upsample:
self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None,
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
if res_hidden_states_tuple is not None:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
hidden_states = self.resnets[0](hidden_states, temb)
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
if self.nonlinearity is not None:
hidden_states = self.nonlinearity(hidden_states)
if self.upsample is not None:
hidden_states = self.upsample(hidden_states)
return hidden_states
class ValueFunctionMidBlock1D(nn.Module):
def __init__(self, in_channels: int, out_channels: int, embed_dim: int):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.embed_dim = embed_dim
self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
self.down1 = Downsample1D(out_channels // 2, use_conv=True)
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
self.down2 = Downsample1D(out_channels // 4, use_conv=True)
def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
x = self.res1(x, temb)
x = self.down1(x)
x = self.res2(x, temb)
x = self.down2(x)
return x
class MidResTemporalBlock1D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
embed_dim: int,
num_layers: int = 1,
add_downsample: bool = False,
add_upsample: bool = False,
non_linearity: Optional[str] = None,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.add_downsample = add_downsample
# there will always be at least one resnet
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
for _ in range(num_layers):
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
self.resnets = nn.ModuleList(resnets)
if non_linearity is None:
self.nonlinearity = None
else:
self.nonlinearity = get_activation(non_linearity)
self.upsample = None
if add_upsample:
self.upsample = Downsample1D(out_channels, use_conv=True)
self.downsample = None
if add_downsample:
self.downsample = Downsample1D(out_channels, use_conv=True)
if self.upsample and self.downsample:
raise ValueError("Block cannot downsample and upsample")
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb)
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
if self.upsample:
hidden_states = self.upsample(hidden_states)
if self.downsample:
self.downsample = self.downsample(hidden_states)
return hidden_states
class OutConv1DBlock(nn.Module):
def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str):
super().__init__()
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
self.final_conv1d_act = get_activation(act_fn)
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.final_conv1d_1(hidden_states)
hidden_states = rearrange_dims(hidden_states)
hidden_states = self.final_conv1d_gn(hidden_states)
hidden_states = rearrange_dims(hidden_states)
hidden_states = self.final_conv1d_act(hidden_states)
hidden_states = self.final_conv1d_2(hidden_states)
return hidden_states
class OutValueFunctionBlock(nn.Module):
def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"):
super().__init__()
self.final_block = nn.ModuleList(
[
nn.Linear(fc_dim + embed_dim, fc_dim // 2),
get_activation(act_fn),
nn.Linear(fc_dim // 2, 1),
]
)
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = hidden_states.view(hidden_states.shape[0], -1)
hidden_states = torch.cat((hidden_states, temb), dim=-1)
for layer in self.final_block:
hidden_states = layer(hidden_states)
return hidden_states
_kernels = {
"linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
"cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
"lanczos3": [
0.003689131001010537,
0.015056144446134567,
-0.03399861603975296,
-0.066637322306633,
0.13550527393817902,
0.44638532400131226,
0.44638532400131226,
0.13550527393817902,
-0.066637322306633,
-0.03399861603975296,
0.015056144446134567,
0.003689131001010537,
],
}
class Downsample1d(nn.Module):
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel])
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
weight[indices, indices] = kernel
return F.conv1d(hidden_states, weight, stride=2)
class Upsample1d(nn.Module):
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel]) * 2
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
weight[indices, indices] = kernel
return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
class SelfAttention1d(nn.Module):
def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
super().__init__()
self.channels = in_channels
self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
self.num_heads = n_head
self.query = nn.Linear(self.channels, self.channels)
self.key = nn.Linear(self.channels, self.channels)
self.value = nn.Linear(self.channels, self.channels)
self.proj_attn = nn.Linear(self.channels, self.channels, bias=True)
self.dropout = nn.Dropout(dropout_rate, inplace=True)
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
return new_projection
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
residual = hidden_states
batch, channel_dim, seq = hidden_states.shape
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
query_states = self.transpose_for_scores(query_proj)
key_states = self.transpose_for_scores(key_proj)
value_states = self.transpose_for_scores(value_proj)
scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
attention_probs = torch.softmax(attention_scores, dim=-1)
# compute attention output
hidden_states = torch.matmul(attention_probs, value_states)
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
hidden_states = hidden_states.view(new_hidden_states_shape)
# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.dropout(hidden_states)
output = hidden_states + residual
return output
class ResConvBlock(nn.Module):
def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False):
super().__init__()
self.is_last = is_last
self.has_conv_skip = in_channels != out_channels
if self.has_conv_skip:
self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)
self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
self.group_norm_1 = nn.GroupNorm(1, mid_channels)
self.gelu_1 = nn.GELU()
self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)
if not self.is_last:
self.group_norm_2 = nn.GroupNorm(1, out_channels)
self.gelu_2 = nn.GELU()
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
hidden_states = self.conv_1(hidden_states)
hidden_states = self.group_norm_1(hidden_states)
hidden_states = self.gelu_1(hidden_states)
hidden_states = self.conv_2(hidden_states)
from ..utils import deprecate
from .unets.unet_1d_blocks import (
AttnDownBlock1D,
AttnUpBlock1D,
DownBlock1D,
DownBlock1DNoSkip,
DownResnetBlock1D,
Downsample1d,
MidResTemporalBlock1D,
OutConv1DBlock,
OutValueFunctionBlock,
ResConvBlock,
SelfAttention1d,
UNetMidBlock1D,
UpBlock1D,
UpBlock1DNoSkip,
UpResnetBlock1D,
Upsample1d,
ValueFunctionMidBlock1D,
)
if not self.is_last:
hidden_states = self.group_norm_2(hidden_states)
hidden_states = self.gelu_2(hidden_states)
output = hidden_states + residual
return output
class DownResnetBlock1D(DownResnetBlock1D):
deprecation_message = "Importing `DownResnetBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownResnetBlock1D`, instead."
deprecate("DownResnetBlock1D", "0.29", deprecation_message)
class UNetMidBlock1D(nn.Module):
def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
super().__init__()
class UpResnetBlock1D(UpResnetBlock1D):
deprecation_message = "Importing `UpResnetBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpResnetBlock1D`, instead."
deprecate("UpResnetBlock1D", "0.29", deprecation_message)
out_channels = in_channels if out_channels is None else out_channels
# there is always at least one resnet
self.down = Downsample1d("cubic")
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(out_channels, out_channels // 32),
]
self.up = Upsample1d(kernel="cubic")
class ValueFunctionMidBlock1D(ValueFunctionMidBlock1D):
deprecation_message = "Importing `ValueFunctionMidBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import ValueFunctionMidBlock1D`, instead."
deprecate("ValueFunctionMidBlock1D", "0.29", deprecation_message)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.down(hidden_states)
for attn, resnet in zip(self.attentions, self.resnets):
hidden_states = resnet(hidden_states)
hidden_states = attn(hidden_states)
class OutConv1DBlock(OutConv1DBlock):
deprecation_message = "Importing `OutConv1DBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import OutConv1DBlock`, instead."
deprecate("OutConv1DBlock", "0.29", deprecation_message)
hidden_states = self.up(hidden_states)
return hidden_states
class OutValueFunctionBlock(OutValueFunctionBlock):
deprecation_message = "Importing `OutValueFunctionBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import OutValueFunctionBlock`, instead."
deprecate("OutValueFunctionBlock", "0.29", deprecation_message)
class AttnDownBlock1D(nn.Module):
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
self.down = Downsample1d("cubic")
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(out_channels, out_channels // 32),
]
class Downsample1d(Downsample1d):
deprecation_message = "Importing `Downsample1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import Downsample1d`, instead."
deprecate("Downsample1d", "0.29", deprecation_message)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.down(hidden_states)
class Upsample1d(Upsample1d):
deprecation_message = "Importing `Upsample1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import Upsample1d`, instead."
deprecate("Upsample1d", "0.29", deprecation_message)
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states)
hidden_states = attn(hidden_states)
return hidden_states, (hidden_states,)
class DownBlock1D(nn.Module):
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
self.down = Downsample1d("cubic")
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.down(hidden_states)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
return hidden_states, (hidden_states,)
class DownBlock1DNoSkip(nn.Module):
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = torch.cat([hidden_states, temb], dim=1)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
return hidden_states, (hidden_states,)
class AttnUpBlock1D(nn.Module):
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
class SelfAttention1d(SelfAttention1d):
deprecation_message = "Importing `SelfAttention1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import SelfAttention1d`, instead."
deprecate("SelfAttention1d", "0.29", deprecation_message)
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(out_channels, out_channels // 32),
]
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
class ResConvBlock(ResConvBlock):
deprecation_message = "Importing `ResConvBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import ResConvBlock`, instead."
deprecate("ResConvBlock", "0.29", deprecation_message)
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states)
hidden_states = attn(hidden_states)
class UNetMidBlock1D(UNetMidBlock1D):
deprecation_message = "Importing `UNetMidBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UNetMidBlock1D`, instead."
deprecate("UNetMidBlock1D", "0.29", deprecation_message)
hidden_states = self.up(hidden_states)
return hidden_states
class AttnDownBlock1D(AttnDownBlock1D):
deprecation_message = "Importing `AttnDownBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import AttnDownBlock1D`, instead."
deprecate("AttnDownBlock1D", "0.29", deprecation_message)
class UpBlock1D(nn.Module):
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = in_channels if mid_channels is None else mid_channels
class DownBlock1D(DownBlock1D):
deprecation_message = "Importing `DownBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownBlock1D`, instead."
deprecate("DownBlock1D", "0.29", deprecation_message)
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
class DownBlock1DNoSkip(DownBlock1DNoSkip):
deprecation_message = "Importing `DownBlock1DNoSkip` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownBlock1DNoSkip`, instead."
deprecate("DownBlock1DNoSkip", "0.29", deprecation_message)
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
class AttnUpBlock1D(AttnUpBlock1D):
deprecation_message = "Importing `AttnUpBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import AttnUpBlock1D`, instead."
deprecate("AttnUpBlock1D", "0.29", deprecation_message)
hidden_states = self.up(hidden_states)
return hidden_states
class UpBlock1D(UpBlock1D):
deprecation_message = "Importing `UpBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpBlock1D`, instead."
deprecate("UpBlock1D", "0.29", deprecation_message)
class UpBlock1DNoSkip(nn.Module):
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = in_channels if mid_channels is None else mid_channels
class UpBlock1DNoSkip(UpBlock1DNoSkip):
deprecation_message = "Importing `UpBlock1DNoSkip` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpBlock1DNoSkip`, instead."
deprecate("UpBlock1DNoSkip", "0.29", deprecation_message)
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
]
self.resnets = nn.ModuleList(resnets)
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
return hidden_states
DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip]
MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D]
OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock]
UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip]
class MidResTemporalBlock1D(MidResTemporalBlock1D):
deprecation_message = "Importing `MidResTemporalBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import MidResTemporalBlock1D`, instead."
deprecate("MidResTemporalBlock1D", "0.29", deprecation_message)
def get_down_block(
@@ -630,42 +126,38 @@ def get_down_block(
out_channels: int,
temb_channels: int,
add_downsample: bool,
) -> DownBlockType:
if down_block_type == "DownResnetBlock1D":
return DownResnetBlock1D(
in_channels=in_channels,
num_layers=num_layers,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
)
elif down_block_type == "DownBlock1D":
return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
elif down_block_type == "AttnDownBlock1D":
return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
elif down_block_type == "DownBlock1DNoSkip":
return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
raise ValueError(f"{down_block_type} does not exist.")
):
deprecation_message = "Importing `get_down_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_down_block`, instead."
deprecate("get_down_block", "0.29", deprecation_message)
from .unets.unet_1d_blocks import get_down_block
return get_down_block(
down_block_type=down_block_type,
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
)
def get_up_block(
up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool
) -> UpBlockType:
if up_block_type == "UpResnetBlock1D":
return UpResnetBlock1D(
in_channels=in_channels,
num_layers=num_layers,
out_channels=out_channels,
temb_channels=temb_channels,
add_upsample=add_upsample,
)
elif up_block_type == "UpBlock1D":
return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
elif up_block_type == "AttnUpBlock1D":
return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
elif up_block_type == "UpBlock1DNoSkip":
return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
raise ValueError(f"{up_block_type} does not exist.")
):
deprecation_message = "Importing `get_up_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_up_block`, instead."
deprecate("get_up_block", "0.29", deprecation_message)
from .unets.unet_1d_blocks import get_up_block
return get_up_block(
up_block_type=up_block_type,
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_upsample=add_upsample,
)
def get_mid_block(
@@ -676,27 +168,36 @@ def get_mid_block(
out_channels: int,
embed_dim: int,
add_downsample: bool,
) -> MidBlockType:
if mid_block_type == "MidResTemporalBlock1D":
return MidResTemporalBlock1D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
embed_dim=embed_dim,
add_downsample=add_downsample,
)
elif mid_block_type == "ValueFunctionMidBlock1D":
return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
elif mid_block_type == "UNetMidBlock1D":
return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
raise ValueError(f"{mid_block_type} does not exist.")
):
deprecation_message = "Importing `get_mid_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_mid_block`, instead."
deprecate("get_mid_block", "0.29", deprecation_message)
from .unets.unet_1d_blocks import get_mid_block
return get_mid_block(
mid_block_type=mid_block_type,
num_layers=num_layers,
in_channels=in_channels,
mid_channels=mid_channels,
out_channels=out_channels,
embed_dim=embed_dim,
add_downsample=add_downsample,
)
def get_out_block(
*, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int
) -> Optional[OutBlockType]:
if out_block_type == "OutConv1DBlock":
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
elif out_block_type == "ValueFunction":
return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
return None
):
deprecation_message = "Importing `get_out_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_out_block`, instead."
deprecate("get_out_block", "0.29", deprecation_message)
from .unets.unet_1d_blocks import get_out_block
return get_out_block(
out_block_type=out_block_type,
num_groups_out=num_groups_out,
embed_dim=embed_dim,
out_channels=out_channels,
act_fn=act_fn,
fc_dim=fc_dim,
)

View File

@@ -11,336 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
@dataclass
class UNet2DOutput(BaseOutput):
"""
The output of [`UNet2DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
The hidden states output from the last layer of the model.
"""
sample: torch.FloatTensor
from ..utils import deprecate
from .unets.unet_2d import UNet2DModel, UNet2DOutput
class UNet2DModel(ModelMixin, ConfigMixin):
r"""
A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
class UNet2DOutput(UNet2DOutput):
deprecation_message = "Importing `UNet2DOutput` from `diffusers.models.unet_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d import UNet2DOutput`, instead."
deprecate("UNet2DOutput", "0.29", deprecation_message)
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
1)`.
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Whether to flip sin to cos for Fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
Tuple of downsample block types.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
Tuple of block output channels.
layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
downsample_type (`str`, *optional*, defaults to `conv`):
The downsample type for downsampling layers. Choose between "conv" and "resnet"
upsample_type (`str`, *optional*, defaults to `conv`):
The upsample type for upsampling layers. Choose between "conv" and "resnet"
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
attn_norm_num_groups (`int`, *optional*, defaults to `None`):
If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
given number of groups. If left as `None`, the group norm layer will only be created if
`resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, or `"identity"`.
num_class_embeds (`int`, *optional*, defaults to `None`):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
conditioning with `class_embed_type` equal to `None`.
"""
@register_to_config
def __init__(
self,
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
in_channels: int = 3,
out_channels: int = 3,
center_input_sample: bool = False,
time_embedding_type: str = "positional",
freq_shift: int = 0,
flip_sin_to_cos: bool = True,
down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
block_out_channels: Tuple[int] = (224, 448, 672, 896),
layers_per_block: int = 2,
mid_block_scale_factor: float = 1,
downsample_padding: int = 1,
downsample_type: str = "conv",
upsample_type: str = "conv",
dropout: float = 0.0,
act_fn: str = "silu",
attention_head_dim: Optional[int] = 8,
norm_num_groups: int = 32,
attn_norm_num_groups: Optional[int] = None,
norm_eps: float = 1e-5,
resnet_time_scale_shift: str = "default",
add_attention: bool = True,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
num_train_timesteps: Optional[int] = None,
):
super().__init__()
self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
# input
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
# time
if time_embedding_type == "fourier":
self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
timestep_input_dim = 2 * block_out_channels[0]
elif time_embedding_type == "positional":
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
elif time_embedding_type == "learned":
self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
else:
self.class_embedding = None
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
downsample_type=downsample_type,
dropout=dropout,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
resnet_groups=norm_num_groups,
attn_groups=attn_norm_num_groups,
add_attention=add_attention,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
resnet_time_scale_shift=resnet_time_scale_shift,
upsample_type=upsample_type,
dropout=dropout,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DOutput, Tuple]:
r"""
The [`UNet2DModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor with the following shape `(batch, channel, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
Returns:
[`~models.unet_2d.UNet2DOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
returned where the first element is the sample tensor.
"""
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when doing class conditioning")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
elif self.class_embedding is None and class_labels is not None:
raise ValueError("class_embedding needs to be initialized in order to use class conditioning")
# 2. pre-process
skip_sample = sample
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "skip_conv"):
sample, res_samples, skip_sample = downsample_block(
hidden_states=sample, temb=emb, skip_sample=skip_sample
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(sample, emb)
# 5. up
skip_sample = None
for upsample_block in self.up_blocks:
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
if hasattr(upsample_block, "skip_conv"):
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
else:
sample = upsample_block(sample, res_samples, emb)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if skip_sample is not None:
sample += skip_sample
if self.config.time_embedding_type == "fourier":
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
sample = sample / timesteps
if not return_dict:
return (sample,)
return UNet2DOutput(sample=sample)
class UNet2DModel(UNet2DModel):
deprecation_message = "Importing `UNet2DModel` from `diffusers.models.unet_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d import UNet2DModel`, instead."
deprecate("UNet2DModel", "0.29", deprecation_message)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,16 @@
from ...utils import is_flax_available, is_torch_available
if is_torch_available():
from .unet_1d import UNet1DModel
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_condition import UNet3DConditionModel
from .unet_kandinsky3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .uvit_2d import UVit2DModel
if is_flax_available():
from .unet_2d_condition_flax import FlaxUNet2DConditionModel

View File

@@ -0,0 +1,255 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
@dataclass
class UNet1DOutput(BaseOutput):
"""
The output of [`UNet1DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
The hidden states output from the last layer of the model.
"""
sample: torch.FloatTensor
class UNet1DModel(ModelMixin, ConfigMixin):
r"""
A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
extra_in_channels (`int`, *optional*, defaults to 0):
Number of additional channels to be added to the input of the first down block. Useful for cases where the
input data has more channels than what the model was initially designed for.
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip sin to cos for Fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
Tuple of block output channels.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
downsample_each_block (`int`, *optional*, defaults to `False`):
Experimental feature for using a UNet without upsampling.
"""
@register_to_config
def __init__(
self,
sample_size: int = 65536,
sample_rate: Optional[int] = None,
in_channels: int = 2,
out_channels: int = 2,
extra_in_channels: int = 0,
time_embedding_type: str = "fourier",
flip_sin_to_cos: bool = True,
use_timestep_embedding: bool = False,
freq_shift: float = 0.0,
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
mid_block_type: Tuple[str] = "UNetMidBlock1D",
out_block_type: str = None,
block_out_channels: Tuple[int] = (32, 32, 64),
act_fn: str = None,
norm_num_groups: int = 8,
layers_per_block: int = 1,
downsample_each_block: bool = False,
):
super().__init__()
self.sample_size = sample_size
# time
if time_embedding_type == "fourier":
self.time_proj = GaussianFourierProjection(
embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
)
timestep_input_dim = 2 * block_out_channels[0]
elif time_embedding_type == "positional":
self.time_proj = Timesteps(
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
)
timestep_input_dim = block_out_channels[0]
if use_timestep_embedding:
time_embed_dim = block_out_channels[0] * 4
self.time_mlp = TimestepEmbedding(
in_channels=timestep_input_dim,
time_embed_dim=time_embed_dim,
act_fn=act_fn,
out_dim=block_out_channels[0],
)
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
self.out_block = None
# down
output_channel = in_channels
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
if i == 0:
input_channel += extra_in_channels
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=block_out_channels[0],
add_downsample=not is_final_block or downsample_each_block,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = get_mid_block(
mid_block_type,
in_channels=block_out_channels[-1],
mid_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
embed_dim=block_out_channels[0],
num_layers=layers_per_block,
add_downsample=downsample_each_block,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
if out_block_type is None:
final_upsample_channels = out_channels
else:
final_upsample_channels = block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = (
reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
)
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block,
in_channels=prev_output_channel,
out_channels=output_channel,
temb_channels=block_out_channels[0],
add_upsample=not is_final_block,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
self.out_block = get_out_block(
out_block_type=out_block_type,
num_groups_out=num_groups_out,
embed_dim=block_out_channels[0],
out_channels=out_channels,
act_fn=act_fn,
fc_dim=block_out_channels[-1] // 4,
)
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
return_dict: bool = True,
) -> Union[UNet1DOutput, Tuple]:
r"""
The [`UNet1DModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
Returns:
[`~models.unet_1d.UNet1DOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
returned where the first element is the sample tensor.
"""
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
timestep_embed = self.time_proj(timesteps)
if self.config.use_timestep_embedding:
timestep_embed = self.time_mlp(timestep_embed)
else:
timestep_embed = timestep_embed[..., None]
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
# 2. down
down_block_res_samples = ()
for downsample_block in self.down_blocks:
sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
down_block_res_samples += res_samples
# 3. mid
if self.mid_block:
sample = self.mid_block(sample, timestep_embed)
# 4. up
for i, upsample_block in enumerate(self.up_blocks):
res_samples = down_block_res_samples[-1:]
down_block_res_samples = down_block_res_samples[:-1]
sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
# 5. post-process
if self.out_block:
sample = self.out_block(sample, timestep_embed)
if not return_dict:
return (sample,)
return UNet1DOutput(sample=sample)

View File

@@ -0,0 +1,702 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from ..activations import get_activation
from ..resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
class DownResnetBlock1D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
num_layers: int = 1,
conv_shortcut: bool = False,
temb_channels: int = 32,
groups: int = 32,
groups_out: Optional[int] = None,
non_linearity: Optional[str] = None,
time_embedding_norm: str = "default",
output_scale_factor: float = 1.0,
add_downsample: bool = True,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.add_downsample = add_downsample
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
# there will always be at least one resnet
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
for _ in range(num_layers):
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
self.resnets = nn.ModuleList(resnets)
if non_linearity is None:
self.nonlinearity = None
else:
self.nonlinearity = get_activation(non_linearity)
self.downsample = None
if add_downsample:
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
output_states = ()
hidden_states = self.resnets[0](hidden_states, temb)
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
output_states += (hidden_states,)
if self.nonlinearity is not None:
hidden_states = self.nonlinearity(hidden_states)
if self.downsample is not None:
hidden_states = self.downsample(hidden_states)
return hidden_states, output_states
class UpResnetBlock1D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
num_layers: int = 1,
temb_channels: int = 32,
groups: int = 32,
groups_out: Optional[int] = None,
non_linearity: Optional[str] = None,
time_embedding_norm: str = "default",
output_scale_factor: float = 1.0,
add_upsample: bool = True,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.time_embedding_norm = time_embedding_norm
self.add_upsample = add_upsample
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
# there will always be at least one resnet
resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
for _ in range(num_layers):
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
self.resnets = nn.ModuleList(resnets)
if non_linearity is None:
self.nonlinearity = None
else:
self.nonlinearity = get_activation(non_linearity)
self.upsample = None
if add_upsample:
self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None,
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
if res_hidden_states_tuple is not None:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
hidden_states = self.resnets[0](hidden_states, temb)
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
if self.nonlinearity is not None:
hidden_states = self.nonlinearity(hidden_states)
if self.upsample is not None:
hidden_states = self.upsample(hidden_states)
return hidden_states
class ValueFunctionMidBlock1D(nn.Module):
def __init__(self, in_channels: int, out_channels: int, embed_dim: int):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.embed_dim = embed_dim
self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
self.down1 = Downsample1D(out_channels // 2, use_conv=True)
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
self.down2 = Downsample1D(out_channels // 4, use_conv=True)
def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
x = self.res1(x, temb)
x = self.down1(x)
x = self.res2(x, temb)
x = self.down2(x)
return x
class MidResTemporalBlock1D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
embed_dim: int,
num_layers: int = 1,
add_downsample: bool = False,
add_upsample: bool = False,
non_linearity: Optional[str] = None,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.add_downsample = add_downsample
# there will always be at least one resnet
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
for _ in range(num_layers):
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
self.resnets = nn.ModuleList(resnets)
if non_linearity is None:
self.nonlinearity = None
else:
self.nonlinearity = get_activation(non_linearity)
self.upsample = None
if add_upsample:
self.upsample = Downsample1D(out_channels, use_conv=True)
self.downsample = None
if add_downsample:
self.downsample = Downsample1D(out_channels, use_conv=True)
if self.upsample and self.downsample:
raise ValueError("Block cannot downsample and upsample")
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb)
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
if self.upsample:
hidden_states = self.upsample(hidden_states)
if self.downsample:
self.downsample = self.downsample(hidden_states)
return hidden_states
class OutConv1DBlock(nn.Module):
def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str):
super().__init__()
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
self.final_conv1d_act = get_activation(act_fn)
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.final_conv1d_1(hidden_states)
hidden_states = rearrange_dims(hidden_states)
hidden_states = self.final_conv1d_gn(hidden_states)
hidden_states = rearrange_dims(hidden_states)
hidden_states = self.final_conv1d_act(hidden_states)
hidden_states = self.final_conv1d_2(hidden_states)
return hidden_states
class OutValueFunctionBlock(nn.Module):
def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"):
super().__init__()
self.final_block = nn.ModuleList(
[
nn.Linear(fc_dim + embed_dim, fc_dim // 2),
get_activation(act_fn),
nn.Linear(fc_dim // 2, 1),
]
)
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = hidden_states.view(hidden_states.shape[0], -1)
hidden_states = torch.cat((hidden_states, temb), dim=-1)
for layer in self.final_block:
hidden_states = layer(hidden_states)
return hidden_states
_kernels = {
"linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
"cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
"lanczos3": [
0.003689131001010537,
0.015056144446134567,
-0.03399861603975296,
-0.066637322306633,
0.13550527393817902,
0.44638532400131226,
0.44638532400131226,
0.13550527393817902,
-0.066637322306633,
-0.03399861603975296,
0.015056144446134567,
0.003689131001010537,
],
}
class Downsample1d(nn.Module):
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel])
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
weight[indices, indices] = kernel
return F.conv1d(hidden_states, weight, stride=2)
class Upsample1d(nn.Module):
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel]) * 2
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
weight[indices, indices] = kernel
return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
class SelfAttention1d(nn.Module):
def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
super().__init__()
self.channels = in_channels
self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
self.num_heads = n_head
self.query = nn.Linear(self.channels, self.channels)
self.key = nn.Linear(self.channels, self.channels)
self.value = nn.Linear(self.channels, self.channels)
self.proj_attn = nn.Linear(self.channels, self.channels, bias=True)
self.dropout = nn.Dropout(dropout_rate, inplace=True)
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
return new_projection
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
residual = hidden_states
batch, channel_dim, seq = hidden_states.shape
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
query_states = self.transpose_for_scores(query_proj)
key_states = self.transpose_for_scores(key_proj)
value_states = self.transpose_for_scores(value_proj)
scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
attention_probs = torch.softmax(attention_scores, dim=-1)
# compute attention output
hidden_states = torch.matmul(attention_probs, value_states)
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
hidden_states = hidden_states.view(new_hidden_states_shape)
# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.dropout(hidden_states)
output = hidden_states + residual
return output
class ResConvBlock(nn.Module):
def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False):
super().__init__()
self.is_last = is_last
self.has_conv_skip = in_channels != out_channels
if self.has_conv_skip:
self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)
self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
self.group_norm_1 = nn.GroupNorm(1, mid_channels)
self.gelu_1 = nn.GELU()
self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)
if not self.is_last:
self.group_norm_2 = nn.GroupNorm(1, out_channels)
self.gelu_2 = nn.GELU()
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
hidden_states = self.conv_1(hidden_states)
hidden_states = self.group_norm_1(hidden_states)
hidden_states = self.gelu_1(hidden_states)
hidden_states = self.conv_2(hidden_states)
if not self.is_last:
hidden_states = self.group_norm_2(hidden_states)
hidden_states = self.gelu_2(hidden_states)
output = hidden_states + residual
return output
class UNetMidBlock1D(nn.Module):
def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
super().__init__()
out_channels = in_channels if out_channels is None else out_channels
# there is always at least one resnet
self.down = Downsample1d("cubic")
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(out_channels, out_channels // 32),
]
self.up = Upsample1d(kernel="cubic")
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.down(hidden_states)
for attn, resnet in zip(self.attentions, self.resnets):
hidden_states = resnet(hidden_states)
hidden_states = attn(hidden_states)
hidden_states = self.up(hidden_states)
return hidden_states
class AttnDownBlock1D(nn.Module):
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
self.down = Downsample1d("cubic")
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(out_channels, out_channels // 32),
]
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.down(hidden_states)
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states)
hidden_states = attn(hidden_states)
return hidden_states, (hidden_states,)
class DownBlock1D(nn.Module):
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
self.down = Downsample1d("cubic")
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.down(hidden_states)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
return hidden_states, (hidden_states,)
class DownBlock1DNoSkip(nn.Module):
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = torch.cat([hidden_states, temb], dim=1)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
return hidden_states, (hidden_states,)
class AttnUpBlock1D(nn.Module):
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(out_channels, out_channels // 32),
]
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states)
hidden_states = attn(hidden_states)
hidden_states = self.up(hidden_states)
return hidden_states
class UpBlock1D(nn.Module):
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = in_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
hidden_states = self.up(hidden_states)
return hidden_states
class UpBlock1DNoSkip(nn.Module):
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = in_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
]
self.resnets = nn.ModuleList(resnets)
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
return hidden_states
DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip]
MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D]
OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock]
UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip]
def get_down_block(
down_block_type: str,
num_layers: int,
in_channels: int,
out_channels: int,
temb_channels: int,
add_downsample: bool,
) -> DownBlockType:
if down_block_type == "DownResnetBlock1D":
return DownResnetBlock1D(
in_channels=in_channels,
num_layers=num_layers,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
)
elif down_block_type == "DownBlock1D":
return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
elif down_block_type == "AttnDownBlock1D":
return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
elif down_block_type == "DownBlock1DNoSkip":
return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
raise ValueError(f"{down_block_type} does not exist.")
def get_up_block(
up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool
) -> UpBlockType:
if up_block_type == "UpResnetBlock1D":
return UpResnetBlock1D(
in_channels=in_channels,
num_layers=num_layers,
out_channels=out_channels,
temb_channels=temb_channels,
add_upsample=add_upsample,
)
elif up_block_type == "UpBlock1D":
return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
elif up_block_type == "AttnUpBlock1D":
return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
elif up_block_type == "UpBlock1DNoSkip":
return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
raise ValueError(f"{up_block_type} does not exist.")
def get_mid_block(
mid_block_type: str,
num_layers: int,
in_channels: int,
mid_channels: int,
out_channels: int,
embed_dim: int,
add_downsample: bool,
) -> MidBlockType:
if mid_block_type == "MidResTemporalBlock1D":
return MidResTemporalBlock1D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
embed_dim=embed_dim,
add_downsample=add_downsample,
)
elif mid_block_type == "ValueFunctionMidBlock1D":
return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
elif mid_block_type == "UNetMidBlock1D":
return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
raise ValueError(f"{mid_block_type} does not exist.")
def get_out_block(
*, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int
) -> Optional[OutBlockType]:
if out_block_type == "OutConv1DBlock":
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
elif out_block_type == "ValueFunction":
return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
return None

View File

@@ -0,0 +1,346 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
@dataclass
class UNet2DOutput(BaseOutput):
"""
The output of [`UNet2DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
The hidden states output from the last layer of the model.
"""
sample: torch.FloatTensor
class UNet2DModel(ModelMixin, ConfigMixin):
r"""
A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
1)`.
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Whether to flip sin to cos for Fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
Tuple of downsample block types.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
Tuple of block output channels.
layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
downsample_type (`str`, *optional*, defaults to `conv`):
The downsample type for downsampling layers. Choose between "conv" and "resnet"
upsample_type (`str`, *optional*, defaults to `conv`):
The upsample type for upsampling layers. Choose between "conv" and "resnet"
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
attn_norm_num_groups (`int`, *optional*, defaults to `None`):
If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
given number of groups. If left as `None`, the group norm layer will only be created if
`resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, or `"identity"`.
num_class_embeds (`int`, *optional*, defaults to `None`):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
conditioning with `class_embed_type` equal to `None`.
"""
@register_to_config
def __init__(
self,
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
in_channels: int = 3,
out_channels: int = 3,
center_input_sample: bool = False,
time_embedding_type: str = "positional",
freq_shift: int = 0,
flip_sin_to_cos: bool = True,
down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
block_out_channels: Tuple[int] = (224, 448, 672, 896),
layers_per_block: int = 2,
mid_block_scale_factor: float = 1,
downsample_padding: int = 1,
downsample_type: str = "conv",
upsample_type: str = "conv",
dropout: float = 0.0,
act_fn: str = "silu",
attention_head_dim: Optional[int] = 8,
norm_num_groups: int = 32,
attn_norm_num_groups: Optional[int] = None,
norm_eps: float = 1e-5,
resnet_time_scale_shift: str = "default",
add_attention: bool = True,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
num_train_timesteps: Optional[int] = None,
):
super().__init__()
self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
# input
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
# time
if time_embedding_type == "fourier":
self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
timestep_input_dim = 2 * block_out_channels[0]
elif time_embedding_type == "positional":
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
elif time_embedding_type == "learned":
self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
else:
self.class_embedding = None
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
downsample_type=downsample_type,
dropout=dropout,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
resnet_groups=norm_num_groups,
attn_groups=attn_norm_num_groups,
add_attention=add_attention,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
resnet_time_scale_shift=resnet_time_scale_shift,
upsample_type=upsample_type,
dropout=dropout,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DOutput, Tuple]:
r"""
The [`UNet2DModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor with the following shape `(batch, channel, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
Returns:
[`~models.unet_2d.UNet2DOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
returned where the first element is the sample tensor.
"""
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when doing class conditioning")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
elif self.class_embedding is None and class_labels is not None:
raise ValueError("class_embedding needs to be initialized in order to use class conditioning")
# 2. pre-process
skip_sample = sample
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "skip_conv"):
sample, res_samples, skip_sample = downsample_block(
hidden_states=sample, temb=emb, skip_sample=skip_sample
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(sample, emb)
# 5. up
skip_sample = None
for upsample_block in self.up_blocks:
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
if hasattr(upsample_block, "skip_conv"):
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
else:
sample = upsample_block(sample, res_samples, emb)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if skip_sample is not None:
sample += skip_sample
if self.config.time_embedding_type == "fourier":
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
sample = sample / timesteps
if not return_dict:
return (sample,)
return UNet2DOutput(sample=sample)

File diff suppressed because it is too large Load Diff

View File

@@ -15,8 +15,8 @@
import flax.linen as nn
import jax.numpy as jnp
from .attention_flax import FlaxTransformer2DModel
from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
from ..attention_flax import FlaxTransformer2DModel
from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
class FlaxCrossAttnDownBlock2D(nn.Module):

File diff suppressed because it is too large Load Diff

View File

@@ -19,10 +19,10 @@ import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..utils import BaseOutput
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from .modeling_flax_utils import FlaxModelMixin
from ...configuration_utils import ConfigMixin, flax_register_to_config
from ...utils import BaseOutput
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from ..modeling_flax_utils import FlaxModelMixin
from .unet_2d_blocks_flax import (
FlaxCrossAttnDownBlock2D,
FlaxCrossAttnUpBlock2D,
@@ -342,14 +342,14 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
mid_block_additional_residual: (`torch.Tensor`, *optional*):
A tensor that if specified is added to the residual of the middle unet block.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
plain tuple.
train (`bool`, *optional*, defaults to `False`):
Use deterministic functions and disable dropout when not training.
Returns:
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
"""
# 1. time

View File

@@ -17,19 +17,19 @@ from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import nn
from ..utils import is_torch_version
from ..utils.torch_utils import apply_freeu
from .attention import Attention
from .dual_transformer_2d import DualTransformer2DModel
from .resnet import (
from ...utils import is_torch_version
from ...utils.torch_utils import apply_freeu
from ..attention import Attention
from ..dual_transformer_2d import DualTransformer2DModel
from ..resnet import (
Downsample2D,
ResnetBlock2D,
SpatioTemporalResBlock,
TemporalConvLayer,
Upsample2D,
)
from .transformer_2d import Transformer2DModel
from .transformer_temporal import (
from ..transformer_2d import Transformer2DModel
from ..transformer_temporal import (
TransformerSpatioTemporalModel,
TransformerTemporalModel,
)

View File

@@ -20,20 +20,20 @@ import torch
import torch.nn as nn
import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, deprecate, logging
from .activations import get_activation
from .attention_processor import (
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
from ...utils import BaseOutput, deprecate, logging
from ..activations import get_activation
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .transformer_temporal import TransformerTemporalModel
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..transformer_temporal import TransformerTemporalModel
from .unet_3d_blocks import (
CrossAttnDownBlock3D,
CrossAttnUpBlock3D,
@@ -284,7 +284,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
)
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
@@ -308,7 +308,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
r"""
Enable sliced attention computation.
@@ -374,7 +374,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -449,7 +449,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
for module in self.children():
fn_recursive_feed_forward(module, None, 0)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
@@ -469,7 +469,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
module.gradient_checkpointing = value
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1, s2, b1, b2):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -494,7 +494,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
setattr(upsample_block, "b1", b1)
setattr(upsample_block, "b2", b2)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
def disable_freeu(self):
"""Disables the FreeU mechanism."""
freeu_keys = {"s1", "s2", "b1", "b2"}
@@ -503,7 +503,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unload_lora
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unload_lora
def unload_lora(self):
"""Unloads LoRA weights."""
deprecate(

View File

@@ -19,11 +19,11 @@ import torch
import torch.utils.checkpoint
from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from .attention_processor import Attention, AttentionProcessor, AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput, logging
from ..attention_processor import Attention, AttentionProcessor, AttnProcessor
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -17,19 +17,19 @@ import torch
import torch.nn as nn
import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import logging
from .attention_processor import (
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
from ...utils import logging
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .transformer_temporal import TransformerTemporalModel
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..transformer_temporal import TransformerTemporalModel
from .unet_2d_blocks import UNetMidBlock2DCrossAttn
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_blocks import (
@@ -524,7 +524,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
)
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
@@ -548,7 +548,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -583,7 +583,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Sets the attention processor to use [feed forward
@@ -613,7 +613,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
def disable_forward_chunking(self) -> None:
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"):
@@ -625,7 +625,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
for module in self.children():
fn_recursive_feed_forward(module, None, 0)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self) -> None:
"""
Disables custom attention processors and sets the default attention implementation.
@@ -645,7 +645,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
module.gradient_checkpointing = value
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -670,7 +670,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
setattr(upsample_block, "b1", b1)
setattr(upsample_block, "b2", b2)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
def disable_freeu(self) -> None:
"""Disables the FreeU mechanism."""
freeu_keys = {"s1", "s2", "b1", "b2"}

View File

@@ -4,12 +4,12 @@ from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging
from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
from ...utils import BaseOutput, logging
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
@@ -323,7 +323,7 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Sets the attention processor to use [feed forward

View File

@@ -20,20 +20,20 @@ import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import PeftAdapterMixin
from .attention import BasicTransformerBlock, SkipFFTransformerBlock
from .attention_processor import (
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ..attention import BasicTransformerBlock, SkipFFTransformerBlock
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .embeddings import TimestepEmbedding, get_timestep_embedding
from .modeling_utils import ModelMixin
from .normalization import GlobalResponseNorm, RMSNorm
from .resnet import Downsample2D, Upsample2D
from ..embeddings import TimestepEmbedding, get_timestep_embedding
from ..modeling_utils import ModelMixin
from ..normalization import GlobalResponseNorm, RMSNorm
from ..resnet import Downsample2D, Upsample2D
class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
@@ -213,7 +213,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
return logits
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
@@ -237,7 +237,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -272,7 +272,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.

View File

@@ -26,7 +26,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unet_motion_model import MotionAdapter
from ...models.unets.unet_motion_model import MotionAdapter
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,

View File

@@ -36,8 +36,8 @@ from ...models.embeddings import (
from ...models.modeling_utils import ModelMixin
from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from ...models.transformer_2d import Transformer2DModel
from ...models.unet_2d_blocks import DownBlock2D, UpBlock2D
from ...models.unet_2d_condition import UNet2DConditionOutput
from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D
from ...models.unets.unet_2d_condition import UNet2DConditionOutput
from ...utils import BaseOutput, is_torch_version, logging
@@ -513,7 +513,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
)
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
@@ -537,7 +537,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -572,7 +572,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
@@ -588,7 +588,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
self.set_attn_processor(processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
@@ -654,7 +654,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
@@ -687,7 +687,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
which adds large negative values to the attention scores corresponding to "discard" tokens.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
@@ -700,8 +700,8 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
which adds large negative values to the attention scores corresponding to "discard" tokens.
Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
a `tuple` is returned where the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.

View File

@@ -33,7 +33,7 @@ from ....models.embeddings import (
)
from ....models.resnet import ResnetBlockCondNorm2D
from ....models.transformer_2d import Transformer2DModel
from ....models.unet_2d_condition import UNet2DConditionOutput
from ....models.unets.unet_2d_condition import UNet2DConditionOutput
from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ....utils.torch_utils import apply_freeu
@@ -268,6 +268,7 @@ class GLIGENTextBoundingboxProjection(nn.Module):
return objs
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
class UNetFlatConditionModel(ModelMixin, ConfigMixin):
r"""
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
@@ -1095,7 +1096,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
which adds large negative values to the attention scores corresponding to "discard" tokens.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
@@ -1111,8 +1112,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
a `tuple` is returned where the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
@@ -1785,7 +1786,7 @@ class CrossAttnDownBlockFlat(nn.Module):
return hidden_states, output_states
# Copied from diffusers.models.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
# Copied from diffusers.models.unets.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
class UpBlockFlat(nn.Module):
def __init__(
self,
@@ -1896,7 +1897,7 @@ class UpBlockFlat(nn.Module):
return hidden_states
# Copied from diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
# Copied from diffusers.models.unets.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
class CrossAttnUpBlockFlat(nn.Module):
def __init__(
self,
@@ -2070,7 +2071,7 @@ class CrossAttnUpBlockFlat(nn.Module):
return hidden_states
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat
# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlat(nn.Module):
"""
A 2D UNet mid-block [`UNetMidBlockFlat`] with multiple residual blocks and optional attention blocks.
@@ -2226,7 +2227,7 @@ class UNetMidBlockFlat(nn.Module):
return hidden_states
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlatCrossAttn(nn.Module):
def __init__(
self,
@@ -2373,7 +2374,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
return hidden_states
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat
# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
def __init__(
self,

View File

@@ -752,7 +752,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
cross_attention_kwargs (*optional*):
Keyword arguments to supply to the cross attention layers, if used.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
hidden_states_is_embedding (`bool`, *optional*, defaults to `False`):
Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will
ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the

View File

@@ -66,7 +66,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
self.set_default_attn_processor()
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
@@ -90,7 +90,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -125,7 +125,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.

View File

@@ -14,7 +14,7 @@
# limitations under the License.
import unittest
from diffusers.models.unet_2d_blocks import * # noqa F403
from diffusers.models.unets.unet_2d_blocks import * # noqa F403
from diffusers.utils.testing_utils import torch_device
from .test_unet_blocks_common import UNetBlockTesterMixin

View File

@@ -28,7 +28,7 @@ from diffusers import (
StableDiffusionXLControlNetPipeline,
UNet2DConditionModel,
)
from diffusers.models.unet_2d_blocks import UNetMidBlock2D
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device