mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
update tests
This commit is contained in:
@@ -15,7 +15,6 @@
|
||||
""" Conversion script for the Versatile Stable Diffusion checkpoints. """
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from argparse import Namespace
|
||||
|
||||
import torch
|
||||
@@ -32,7 +31,6 @@ from diffusers import (
|
||||
VersatileDiffusionPipeline,
|
||||
)
|
||||
from diffusers.pipelines.versatile_diffusion.modeling_text_unet import UNetFlatConditionModel
|
||||
from diffusers.pipelines.versatile_diffusion.modeling_gpt2_optimus import GPT2OptimusForLatentConnector
|
||||
from transformers import (
|
||||
CLIPFeatureExtractor,
|
||||
CLIPTextModelWithProjection,
|
||||
|
||||
@@ -73,10 +73,10 @@ if is_torch_available() and is_transformers_available():
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
StableDiffusionPipeline,
|
||||
VersatileDiffusionImageToTextPipeline,
|
||||
VersatileDiffusionImageVariationPipeline,
|
||||
VersatileDiffusionPipeline,
|
||||
VersatileDiffusionTextToImagePipeline,
|
||||
VersatileDiffusionImageToTextPipeline,
|
||||
VQDiffusionPipeline,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -25,10 +25,10 @@ if is_torch_available() and is_transformers_available():
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from .versatile_diffusion import (
|
||||
VersatileDiffusionImageToTextPipeline,
|
||||
VersatileDiffusionImageVariationPipeline,
|
||||
VersatileDiffusionPipeline,
|
||||
VersatileDiffusionTextToImagePipeline,
|
||||
VersatileDiffusionImageToTextPipeline,
|
||||
)
|
||||
from .vq_diffusion import VQDiffusionPipeline
|
||||
|
||||
|
||||
@@ -4,8 +4,7 @@ from ...utils import is_torch_available, is_transformers_available
|
||||
if is_transformers_available() and is_torch_available():
|
||||
from .modeling_gpt2_optimus import GPT2OptimusForLatentConnector
|
||||
from .modeling_text_unet import UNetFlatConditionModel
|
||||
|
||||
from .pipeline_versatile_diffusion import VersatileDiffusionPipeline
|
||||
from .pipeline_versatile_diffusion_image_to_text import VersatileDiffusionImageToTextPipeline
|
||||
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
|
||||
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
|
||||
from .pipeline_versatile_diffusion_image_to_text import VersatileDiffusionImageToTextPipeline
|
||||
|
||||
@@ -489,7 +489,9 @@ class ResnetBlockFlat(nn.Module):
|
||||
|
||||
self.nonlinearity = nn.SiLU()
|
||||
|
||||
self.use_in_shortcut = self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut
|
||||
self.use_in_shortcut = (
|
||||
self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut
|
||||
)
|
||||
|
||||
self.conv_shortcut = None
|
||||
if self.use_in_shortcut:
|
||||
@@ -527,7 +529,6 @@ class ResnetBlockFlat(nn.Module):
|
||||
output_tensor = output_tensor.view(*shape[0:-n_dim], -1)
|
||||
output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_channels_multidim)
|
||||
|
||||
print("resblock.output_tensor", output_tensor.abs().sum())
|
||||
return output_tensor
|
||||
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from ...utils import logging
|
||||
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
|
||||
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
|
||||
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -23,13 +23,13 @@ import torch.utils.checkpoint
|
||||
import PIL
|
||||
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection, GPT2Tokenizer
|
||||
|
||||
from .modeling_text_unet import UNetFlatConditionModel
|
||||
from .modeling_gpt2_optimus import GPT2OptimusForLatentConnector
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models.attention import Transformer2DModel
|
||||
from ...pipeline_utils import DiffusionPipeline, BaseOutput
|
||||
from ...pipeline_utils import BaseOutput, DiffusionPipeline
|
||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from ...utils import is_accelerate_available, logging
|
||||
from .modeling_gpt2_optimus import GPT2OptimusForLatentConnector
|
||||
from .modeling_text_unet import UNetFlatConditionModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -42,7 +42,8 @@ class TextPipelineOutput(BaseOutput):
|
||||
|
||||
Args:
|
||||
text (`List[str]` or `np.ndarray`)
|
||||
List of generated text of length `batch_size` or a numpy array of tokens of shape `(batch_size, num_tokens)`.
|
||||
List of generated text of length `batch_size` or a numpy array of tokens of shape `(batch_size,
|
||||
num_tokens)`.
|
||||
"""
|
||||
|
||||
text: Union[List[str], np.ndarray]
|
||||
|
||||
@@ -20,12 +20,12 @@ import torch.utils.checkpoint
|
||||
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from .modeling_text_unet import UNetFlatConditionModel
|
||||
from ...models import UNet2DConditionModel, AutoencoderKL
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models.attention import Transformer2DModel
|
||||
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from ...utils import is_accelerate_available, logging
|
||||
from .modeling_text_unet import UNetFlatConditionModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -27,15 +27,15 @@ from ...test_pipelines_common import PipelineTesterMixin
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class VersatileDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
class VersatileDiffusionImageToTextPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class VersatileDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase):
|
||||
class VersatileDiffusionImageToTextPipelineIntegrationTests(unittest.TestCase):
|
||||
def test_inference_image_to_text(self):
|
||||
pipe = VersatileDiffusionImageToTextPipeline.from_pretrained("scripts/vd_official")
|
||||
pipe = VersatileDiffusionImageToTextPipeline.from_pretrained("diffusers/vd-official-test")
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
@@ -53,4 +53,4 @@ class VersatileDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase
|
||||
|
||||
assert tokens.shape == (1, 30)
|
||||
expected_tokens = np.array([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
assert self.assertItemsEqual(tokens[0] , expected_tokens)
|
||||
assert self.assertItemsEqual(tokens[0], expected_tokens)
|
||||
|
||||
Reference in New Issue
Block a user