1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

move pipelines into folders

This commit is contained in:
Patrick von Platen
2022-06-28 10:47:47 +00:00
parent 0efac0aac9
commit bdecc3cffd
24 changed files with 70 additions and 52 deletions

View File

@@ -57,17 +57,19 @@ class DiffusionPipeline(ConfigMixin):
from diffusers import pipelines
for name, module in kwargs.items():
# check if the module is a pipeline module
is_pipeline_module = hasattr(pipelines, module.__module__.split(".")[-1])
# retrieve library
library = module.__module__.split(".")[0]
# check if the module is a pipeline module
pipeline_file = module.__module__.split(".")[-1]
pipeline_dir = module.__module__.split(".")[-2]
is_pipeline_module = pipeline_file == "pipeline_" + pipeline_dir and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# so we set the library to module name.
# folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = module.__module__.split(".")[-1]
library = pipeline_dir
# retrieve class_name
class_name = module.__class__.__name__

View File

@@ -1,20 +1,17 @@
from ..utils import is_inflect_available, is_transformers_available, is_unidecode_available
from .pipeline_bddm import BDDMPipeline
from .pipeline_ddim import DDIMPipeline
from .pipeline_ddpm import DDPMPipeline
from .pipeline_latent_diffusion_uncond import LatentDiffusionUncondPipeline
from .pipeline_pndm import PNDMPipeline
from .pipeline_score_sde_ve import ScoreSdeVePipeline
from .pipeline_score_sde_vp import ScoreSdeVpPipeline
# from .pipeline_score_sde import ScoreSdeVePipeline
from .bddm import BDDMPipeline
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
from .latent_diffusion_uncond import LatentDiffusionUncondPipeline
from .pndm import PNDMPipeline
from .score_sde_ve import ScoreSdeVePipeline
from .score_sde_vp import ScoreSdeVpPipeline
if is_transformers_available():
from .pipeline_glide import GlidePipeline
from .pipeline_latent_diffusion import LatentDiffusionPipeline
from .glide import GlidePipeline
from .latent_diffusion import LatentDiffusionPipeline
if is_transformers_available() and is_unidecode_available() and is_inflect_available():
from .pipeline_grad_tts import GradTTSPipeline
from .grad_tts import GradTTSPipeline

View File

@@ -0,0 +1 @@
from .pipeline_bddm import BDDMPipeline, DiffWave

View File

@@ -21,9 +21,9 @@ import torch.nn.functional as F
import tqdm
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from ..pipeline_utils import DiffusionPipeline
from ...configuration_utils import ConfigMixin
from ...modeling_utils import ModelMixin
from ...pipeline_utils import DiffusionPipeline
def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):

View File

@@ -0,0 +1 @@
from .pipeline_ddim import DDIMPipeline

View File

@@ -18,7 +18,7 @@ import torch
import tqdm
from ..pipeline_utils import DiffusionPipeline
from ...pipeline_utils import DiffusionPipeline
class DDIMPipeline(DiffusionPipeline):

View File

@@ -0,0 +1 @@
from .pipeline_ddpm import DDPMPipeline

View File

@@ -18,7 +18,7 @@ import torch
import tqdm
from ..pipeline_utils import DiffusionPipeline
from ...pipeline_utils import DiffusionPipeline
class DDPMPipeline(DiffusionPipeline):

View File

@@ -0,0 +1,5 @@
from ...utils import is_transformers_available
if is_transformers_available():
from .pipeline_glide import CLIPTextModel, GlidePipeline

View File

@@ -18,7 +18,6 @@ import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
@@ -30,10 +29,10 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings
from ..models import GlideSuperResUNetModel, GlideTextToImageUNetModel
from ..pipeline_utils import DiffusionPipeline
from ..schedulers import DDIMScheduler, DDPMScheduler
from ..utils import logging
from ...models import GlideSuperResUNetModel, GlideTextToImageUNetModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, DDPMScheduler
from ...utils import logging
#####################
@@ -594,7 +593,7 @@ class CLIPTextTransformer(nn.Module):
bsz, seq_len = input_shape
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device)
# causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device)
# expand attention_mask
if attention_mask is not None:

View File

@@ -0,0 +1,6 @@
from ...utils import is_inflect_available, is_transformers_available, is_unidecode_available
if is_transformers_available() and is_unidecode_available() and is_inflect_available():
from .grad_tts_utils import GradTTSTokenizer
from .pipeline_grad_tts import GradTTSPipeline, TextEncoder

View File

@@ -6,10 +6,10 @@ import torch
from torch import nn
import tqdm
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin
from diffusers.modeling_utils import ModelMixin
from ...configuration_utils import ConfigMixin
from ...modeling_utils import ModelMixin
from ...pipeline_utils import DiffusionPipeline
from .grad_tts_utils import GradTTSTokenizer # flake8: noqa

View File

@@ -0,0 +1,5 @@
from ...utils import is_transformers_available
if is_transformers_available():
from .pipeline_latent_diffusion import AutoencoderKL, LatentDiffusionPipeline, LDMBertModel

View File

@@ -7,20 +7,15 @@ import torch.nn as nn
import torch.utils.checkpoint
import tqdm
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
try:
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
except ImportError:
raise ImportError("Please install the transformers.")
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from ..pipeline_utils import DiffusionPipeline
from ...configuration_utils import ConfigMixin
from ...modeling_utils import ModelMixin
from ...pipeline_utils import DiffusionPipeline
################################################################################

View File

@@ -0,0 +1 @@
from .pipeline_latent_diffusion_uncond import LatentDiffusionUncondPipeline

View File

@@ -6,9 +6,9 @@ import torch.nn as nn
import tqdm
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from ..pipeline_utils import DiffusionPipeline
from ...configuration_utils import ConfigMixin
from ...modeling_utils import ModelMixin
from ...pipeline_utils import DiffusionPipeline
def get_timestep_embedding(timesteps, embedding_dim):

View File

@@ -0,0 +1 @@
from .pipeline_pndm import PNDMPipeline

View File

@@ -18,7 +18,7 @@ import torch
import tqdm
from ..pipeline_utils import DiffusionPipeline
from ...pipeline_utils import DiffusionPipeline
class PNDMPipeline(DiffusionPipeline):

View File

@@ -0,0 +1 @@
from .pipeline_score_sde_ve import ScoreSdeVePipeline

View File

@@ -0,0 +1 @@
from .pipeline_score_sde_vp import ScoreSdeVpPipeline

View File

@@ -47,7 +47,7 @@ from diffusers import (
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.pipeline_bddm import DiffWave
from diffusers.pipelines.bddm.pipeline_bddm import DiffWave
from diffusers.testing_utils import floats_tensor, slow, torch_device
@@ -1005,11 +1005,13 @@ class PipelineTesterMixin(unittest.TestCase):
bddm = BDDMPipeline(model, noise_scheduler)
# check if the library name for the diffwave module is set to pipeline module
self.assertTrue(bddm.config["diffwave"][0] == "pipeline_bddm")
self.assertTrue(bddm.config["diffwave"][0] == "bddm")
# check if we can save and load the pipeline
with tempfile.TemporaryDirectory() as tmpdirname:
bddm.save_pretrained(tmpdirname)
_ = BDDMPipeline.from_pretrained(tmpdirname)
# check if the same works using the DiffusionPipeline class
_ = DiffusionPipeline.from_pretrained(tmpdirname)
bddm = DiffusionPipeline.from_pretrained(tmpdirname)
self.assertTrue(bddm.config["diffwave"][0] == "bddm")