mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
modify pipeline
@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2023 TencentARC and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,9 +13,7 @@
 # limitations under the License.

 import inspect
-import os
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-from dataclasses import dataclass

 import torch
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
@@ -42,20 +40,7 @@ from ...utils import (
     replace_example_docstring,
 )
 from ..pipeline_utils import DiffusionPipeline
-from basicsr.utils import tensor2img
-
-@dataclass
-class StableDiffusionAdapterXLPipelineOutput(BaseOutput):
-    """
-    Output class for Stable Diffusion pipelines.
-
-    Args:
-        images (`List[PIL.Image.Image]` or `np.ndarray`)
-            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
-            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
-    """
-
-    images: Union[List[PIL.Image.Image], np.ndarray]
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
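Note on the hunk above: the commit deletes the locally defined StableDiffusionAdapterXLPipelineOutput dataclass (along with the basicsr/tensor2img dependency) in favor of the output class diffusers already ships. A minimal sketch of why this is a drop-in swap, assuming a diffusers release that exports the SDXL pipelines; callers only touch the `images` field, which both classes expose:

    import PIL.Image
    from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput

    # Same single-field access pattern as the removed local dataclass.
    out = StableDiffusionXLPipelineOutput(images=[PIL.Image.new("RGB", (64, 64))])
    assert len(out.images) == 1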
@@ -213,6 +198,7 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
         """
         self.vae.disable_tiling()

+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.enable_model_cpu_offload
     def enable_model_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
@@ -243,6 +229,7 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
         # We'll offload the last model manually.
         self.final_offload_hook = hook

+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.encode_prompt
     def encode_prompt(
         self,
         prompt: str,
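The `hook` kept in `final_offload_hook` above comes from accelerate's per-model CPU offload. A hedged sketch of that pattern, with a toy module standing in for the real text encoders/UNet/VAE (requires accelerate and a CUDA device):

    import torch
    from accelerate import cpu_offload_with_hook

    model = torch.nn.Linear(8, 8)  # stand-in for a pipeline submodel
    model, hook = cpu_offload_with_hook(model, torch.device("cuda"))
    model(torch.randn(1, 8, device="cuda"))  # weights move to the GPU on demand
    hook.offload()  # what the pipeline does manually with the last model's hook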
@@ -452,6 +439,7 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
         extra_step_kwargs["generator"] = generator
         return extra_step_kwargs

+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.check_inputs
     def check_inputs(
         self,
         prompt,
@@ -543,6 +531,7 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
         latents = latents * self.scheduler.init_noise_sigma
         return latents

+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl._get_add_time_ids
     def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
         add_time_ids = list(original_size + crops_coords_top_left + target_size)

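The `_get_add_time_ids` line in this hunk is plain tuple concatenation: SDXL's three micro-conditioning pairs are flattened into one six-element list. A tiny worked example with arbitrary values:

    original_size = (1024, 1024)
    crops_coords_top_left = (0, 0)
    target_size = (1024, 1024)
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    assert add_time_ids == [1024, 1024, 0, 0, 1024, 1024]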
@@ -579,6 +568,7 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
         self.vae.decoder.conv_in.to(dtype)
         self.vae.decoder.mid_block.to(dtype)

+    # Copied from diffusers.pipelines.t2i_adapter.pipeline_stable_diffusion_adapter._default_height_width
     def _default_height_width(self, height, width, image):
         # NOTE: It is possible that a list of images have different
         # dimensions for each image, so just checking the first image
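A simplified re-sketch of the helper this `# Copied from` line points at, to make the NOTE concrete. This assumes PIL input only; the real helper also handles tensors and rounds to model-friendly multiples:

    import PIL.Image

    def _default_height_width(height, width, image):
        if isinstance(image, list):
            image = image[0]  # sizes may differ across the list; check the first
        if height is None:
            height = image.height
        if width is None:
            width = image.width
        return height, width

    print(_default_height_width(None, None, PIL.Image.new("RGB", (768, 512))))  # (512, 768)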
@@ -753,16 +743,14 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
             `tuple`. When returning a tuple, the first element is a list with the generated images.
         """
         # 0. Default height and width to unet
-        # height = height or self.default_sample_size * self.vae_scale_factor
-        # width = width or self.default_sample_size * self.vae_scale_factor

         height, width = self._default_height_width(height, width, image)
         device = self._execution_device

         adapter_input = _preprocess_adapter_image(image, height, width).to(device)

-        original_size = (height, width) #original_size or (height, width)
-        target_size = (height, width) #target_size or (height, width)
+        original_size = (height, width)
+        target_size = (height, width)

         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
@@ -795,9 +783,6 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
         do_classifier_free_guidance = guidance_scale > 1.0

         # 3. Encode input prompt
-        # text_encoder_lora_scale = (
-        #     cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
-        # )
         (
             prompt_embeds,
             negative_prompt_embeds,
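Context for `do_classifier_free_guidance = guidance_scale > 1.0` in this hunk: a scale above 1.0 enables the standard classifier-free guidance combination used throughout diffusers. Toy tensors stand in for real UNet predictions:

    import torch

    guidance_scale = 7.5
    noise_pred = torch.randn(2, 4, 128, 128)  # [uncond, text] stacked on the batch dim
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)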
@@ -815,7 +800,6 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
             negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-            # lora_scale=text_encoder_lora_scale,
         )

         # 4. Prepare timesteps
@@ -925,7 +909,7 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         else:
             image = latents
-            return StableDiffusionAdapterXLPipelineOutput(images=image)
+            return StableDiffusionXLPipelineOutput(images=image)

         image = self.image_processor.postprocess(image, output_type=output_type)

@@ -936,4 +920,4 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
         if not return_dict:
             return (image,)

-        return StableDiffusionAdapterXLPipelineOutput(images=image)
+        return StableDiffusionXLPipelineOutput(images=image)
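End-to-end, the change means callers get the stock diffusers output type back. A hedged usage sketch; the model and adapter IDs are illustrative assumptions, not taken from this commit, and the control image is a blank stand-in:

    import torch
    import PIL.Image
    from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter

    adapter = T2IAdapter.from_pretrained(
        "TencentARC/t2i-adapter-sketch-sdxl-1.0", torch_dtype=torch.float16
    )
    pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", adapter=adapter, torch_dtype=torch.float16
    ).to("cuda")

    control = PIL.Image.new("L", (1024, 1024))
    out = pipe("a sketch of a cat", image=control)  # StableDiffusionXLPipelineOutput
    (image,) = pipe("a sketch of a cat", image=control, return_dict=False)  # plain tuple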