mirror of https://github.com/huggingface/diffusers.git

modify pipeline

Chong
2023-08-23 13:25:15 +08:00
parent 2189e8a4e3
commit e262231ed0


@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2023 TencentARC and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,9 +13,7 @@
 # limitations under the License.
 import inspect
-import os
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-from dataclasses import dataclass

 import torch
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
@@ -42,20 +40,7 @@ from ...utils import (
     replace_example_docstring,
 )
 from ..pipeline_utils import DiffusionPipeline
-from basicsr.utils import tensor2img
-
-
-@dataclass
-class StableDiffusionAdapterXLPipelineOutput(BaseOutput):
-    """
-    Output class for Stable Diffusion pipelines.
-
-    Args:
-        images (`List[PIL.Image.Image]` or `np.ndarray`)
-            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
-            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
-    """
-
-    images: Union[List[PIL.Image.Image], np.ndarray]
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -213,6 +198,7 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
         """
         self.vae.disable_tiling()

+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.enable_model_cpu_offload
     def enable_model_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
@@ -243,6 +229,7 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
         # We'll offload the last model manually.
         self.final_offload_hook = hook

+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.encode_prompt
     def encode_prompt(
         self,
         prompt: str,
@@ -452,6 +439,7 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
            extra_step_kwargs["generator"] = generator
         return extra_step_kwargs

+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.check_inputs
     def check_inputs(
         self,
         prompt,
@@ -543,6 +531,7 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
         latents = latents * self.scheduler.init_noise_sigma
         return latents

+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl._get_add_time_ids
     def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
         add_time_ids = list(original_size + crops_coords_top_left + target_size)
@@ -579,6 +568,7 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
         self.vae.decoder.conv_in.to(dtype)
         self.vae.decoder.mid_block.to(dtype)

+    # Copied from diffusers.pipelines.t2i_adapter.pipeline_stable_diffusion_adapter._default_height_width
     def _default_height_width(self, height, width, image):
         # NOTE: It is possible that a list of images have different
         # dimensions for each image, so just checking the first image
@@ -753,16 +743,14 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
                `tuple`. When returning a tuple, the first element is a list with the generated images.
         """
         # 0. Default height and width to unet
-        # height = height or self.default_sample_size * self.vae_scale_factor
-        # width = width or self.default_sample_size * self.vae_scale_factor
         height, width = self._default_height_width(height, width, image)
         device = self._execution_device

         adapter_input = _preprocess_adapter_image(image, height, width).to(device)

-        original_size = (height, width) #original_size or (height, width)
-        target_size = (height, width) #target_size or (height, width)
+        original_size = (height, width)
+        target_size = (height, width)

         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
@@ -795,9 +783,6 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
         do_classifier_free_guidance = guidance_scale > 1.0

         # 3. Encode input prompt
-        # text_encoder_lora_scale = (
-        #     cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
-        # )
         (
             prompt_embeds,
             negative_prompt_embeds,
@@ -815,7 +800,6 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
             negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-            # lora_scale=text_encoder_lora_scale,
         )

         # 4. Prepare timesteps
@@ -925,7 +909,7 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         else:
             image = latents
-            return StableDiffusionAdapterXLPipelineOutput(images=image)
+            return StableDiffusionXLPipelineOutput(images=image)

         image = self.image_processor.postprocess(image, output_type=output_type)
@@ -936,4 +920,4 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
         if not return_dict:
             return (image,)

-        return StableDiffusionAdapterXLPipelineOutput(images=image)
+        return StableDiffusionXLPipelineOutput(images=image)
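
With this change the pipeline returns the shared StableDiffusionXLPipelineOutput instead of the file-local StableDiffusionAdapterXLPipelineOutput dataclass, so callers see the same output type as the other SDXL pipelines. A minimal usage sketch under that assumption; the checkpoint IDs and the image path below are placeholders, not taken from this commit:

import torch
from PIL import Image

from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput

# Placeholder checkpoints -- substitute real T2I-Adapter / SDXL base repo IDs.
adapter = T2IAdapter.from_pretrained("<adapter-repo-id>", torch_dtype=torch.float16)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    "<sdxl-base-repo-id>", adapter=adapter, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

# Conditioning image for the T2I-Adapter (placeholder path).
control_image = Image.open("control.png")

result = pipe(prompt="a photo of a cat", image=control_image)
# The pipeline now returns the standard SDXL output class.
assert isinstance(result, StableDiffusionXLPipelineOutput)
images = result.images  # list of PIL.Image.Image by default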