1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

update tests

This commit is contained in:
anton-l
2022-11-22 01:47:08 +01:00
parent f706729d3c
commit bf8f2fb2c9
9 changed files with 18 additions and 19 deletions

View File

@@ -15,7 +15,6 @@
""" Conversion script for the Versatile Stable Diffusion checkpoints. """
import argparse
import os
from argparse import Namespace
import torch
@@ -32,7 +31,6 @@ from diffusers import (
VersatileDiffusionPipeline,
)
from diffusers.pipelines.versatile_diffusion.modeling_text_unet import UNetFlatConditionModel
from diffusers.pipelines.versatile_diffusion.modeling_gpt2_optimus import GPT2OptimusForLatentConnector
from transformers import (
CLIPFeatureExtractor,
CLIPTextModelWithProjection,

View File

@@ -73,10 +73,10 @@ if is_torch_available() and is_transformers_available():
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionPipeline,
VersatileDiffusionImageToTextPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
VersatileDiffusionImageToTextPipeline,
VQDiffusionPipeline,
)
else:

View File

@@ -25,10 +25,10 @@ if is_torch_available() and is_transformers_available():
StableDiffusionPipeline,
)
from .versatile_diffusion import (
VersatileDiffusionImageToTextPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
VersatileDiffusionImageToTextPipeline,
)
from .vq_diffusion import VQDiffusionPipeline

View File

@@ -4,8 +4,7 @@ from ...utils import is_torch_available, is_transformers_available
if is_transformers_available() and is_torch_available():
from .modeling_gpt2_optimus import GPT2OptimusForLatentConnector
from .modeling_text_unet import UNetFlatConditionModel
from .pipeline_versatile_diffusion import VersatileDiffusionPipeline
from .pipeline_versatile_diffusion_image_to_text import VersatileDiffusionImageToTextPipeline
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
from .pipeline_versatile_diffusion_image_to_text import VersatileDiffusionImageToTextPipeline

View File

@@ -489,7 +489,9 @@ class ResnetBlockFlat(nn.Module):
self.nonlinearity = nn.SiLU()
self.use_in_shortcut = self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut
self.use_in_shortcut = (
self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut
)
self.conv_shortcut = None
if self.use_in_shortcut:
@@ -527,7 +529,6 @@ class ResnetBlockFlat(nn.Module):
output_tensor = output_tensor.view(*shape[0:-n_dim], -1)
output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_channels_multidim)
print("resblock.output_tensor", output_tensor.abs().sum())
return output_tensor

View File

@@ -9,8 +9,8 @@ from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -23,13 +23,13 @@ import torch.utils.checkpoint
import PIL
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection, GPT2Tokenizer
from .modeling_text_unet import UNetFlatConditionModel
from .modeling_gpt2_optimus import GPT2OptimusForLatentConnector
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention import Transformer2DModel
from ...pipeline_utils import DiffusionPipeline, BaseOutput
from ...pipeline_utils import BaseOutput, DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import is_accelerate_available, logging
from .modeling_gpt2_optimus import GPT2OptimusForLatentConnector
from .modeling_text_unet import UNetFlatConditionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -42,7 +42,8 @@ class TextPipelineOutput(BaseOutput):
Args:
text (`List[str]` or `np.ndarray`)
List of generated text of length `batch_size` or a numpy array of tokens of shape `(batch_size, num_tokens)`.
List of generated text of length `batch_size` or a numpy array of tokens of shape `(batch_size,
num_tokens)`.
"""
text: Union[List[str], np.ndarray]

View File

@@ -20,12 +20,12 @@ import torch.utils.checkpoint
from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer
from .modeling_text_unet import UNetFlatConditionModel
from ...models import UNet2DConditionModel, AutoencoderKL
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention import Transformer2DModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import is_accelerate_available, logging
from .modeling_text_unet import UNetFlatConditionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -27,15 +27,15 @@ from ...test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
class VersatileDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
class VersatileDiffusionImageToTextPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pass
@slow
@require_torch_gpu
class VersatileDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase):
class VersatileDiffusionImageToTextPipelineIntegrationTests(unittest.TestCase):
def test_inference_image_to_text(self):
pipe = VersatileDiffusionImageToTextPipeline.from_pretrained("scripts/vd_official")
pipe = VersatileDiffusionImageToTextPipeline.from_pretrained("diffusers/vd-official-test")
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -53,4 +53,4 @@ class VersatileDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase
assert tokens.shape == (1, 30)
expected_tokens = np.array([0, 1, 2, 3, 4, 5, 6, 7])
assert self.assertItemsEqual(tokens[0] , expected_tokens)
assert self.assertItemsEqual(tokens[0], expected_tokens)