From fb9e37adf6925cde627f1b50b910bed32d481719 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 15 Jun 2022 15:52:23 +0200 Subject: [PATCH] correct logging --- examples/train_ddpm.py | 2 +- src/diffusers/pipelines/__init__.py | 5 +++++ src/diffusers/pipelines/pipeline_glide.py | 11 +++++++++-- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/examples/train_ddpm.py b/examples/train_ddpm.py index fa13e346f6..6c7333a720 100644 --- a/examples/train_ddpm.py +++ b/examples/train_ddpm.py @@ -9,10 +9,10 @@ from accelerate import Accelerator from datasets import load_dataset from diffusers import DDPM, DDPMScheduler, UNetModel from torchvision.transforms import ( + CenterCrop, Compose, InterpolationMode, Lambda, - CenterCrop, RandomHorizontalFlip, Resize, ToTensor, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index e7ab7db472..fb5719cdf4 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -1,10 +1,15 @@ from .pipeline_bddm import BDDM from .pipeline_ddim import DDIM from .pipeline_ddpm import DDPM + + try: from .pipeline_glide import GLIDE except ImportError: + class GLIDE: pass + + from .pipeline_latent_diffusion import LatentDiffusion from .pipeline_pndm import PNDM diff --git a/src/diffusers/pipelines/pipeline_glide.py b/src/diffusers/pipelines/pipeline_glide.py index ed378e4ae7..38e2bc54fd 100644 --- a/src/diffusers/pipelines/pipeline_glide.py +++ b/src/diffusers/pipelines/pipeline_glide.py @@ -15,7 +15,6 @@ """ PyTorch CLIP model.""" import math -import logging from dataclasses import dataclass from typing import Any, Optional, Tuple, Union @@ -25,12 +24,19 @@ import torch.utils.checkpoint from torch import nn import tqdm + + try: from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer from transformers.activations import ACT2FN from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.modeling_utils import PreTrainedModel - from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings + from transformers.utils import ( + ModelOutput, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + ) except: print("Transformers is not installed") pass @@ -38,6 +44,7 @@ except: from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel from ..pipeline_utils import DiffusionPipeline from ..schedulers import ClassifierFreeGuidanceScheduler, DDIMScheduler +from ..utils import logging #####################