diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
index 15ad703029..b7bea38337 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2023 TencentARC and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,9 +13,7 @@
 # limitations under the License.
 
 import inspect
-import os
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-from dataclasses import dataclass
 
 import torch
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
@@ -42,20 +40,7 @@ from ...utils import (
     replace_example_docstring,
 )
 from ..pipeline_utils import DiffusionPipeline
-from basicsr.utils import tensor2img
-
-@dataclass
-class StableDiffusionAdapterXLPipelineOutput(BaseOutput):
-    """
-    Output class for Stable Diffusion pipelines.
-
-    Args:
-        images (`List[PIL.Image.Image]` or `np.ndarray`)
-            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
-            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
-    """
-
-    images: Union[List[PIL.Image.Image], np.ndarray]
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -213,6 +198,7 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
         """
         self.vae.disable_tiling()
 
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.enable_model_cpu_offload
     def enable_model_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
@@ -243,6 +229,7 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
         # We'll offload the last model manually.
         self.final_offload_hook = hook
 
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.encode_prompt
     def encode_prompt(
         self,
         prompt: str,
@@ -452,6 +439,7 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.check_inputs
     def check_inputs(
         self,
         prompt,
@@ -543,6 +531,7 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
         latents = latents * self.scheduler.init_noise_sigma
         return latents
 
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl._get_add_time_ids
     def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
         add_time_ids = list(original_size + crops_coords_top_left + target_size)
 
@@ -579,6 +568,7 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
         self.vae.decoder.conv_in.to(dtype)
         self.vae.decoder.mid_block.to(dtype)
 
+    # Copied from diffusers.pipelines.t2i_adapter.pipeline_stable_diffusion_adapter._default_height_width
     def _default_height_width(self, height, width, image):
         # NOTE: It is possible that a list of images have different
        # dimensions for each image, so just checking the first image
@@ -753,16 +743,14 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
                 `tuple`. When returning a tuple, the first element is a list with the generated images.
         """
         # 0. Default height and width to unet
-        # height = height or self.default_sample_size * self.vae_scale_factor
-        # width = width or self.default_sample_size * self.vae_scale_factor
         height, width = self._default_height_width(height, width, image)
 
         device = self._execution_device
 
         adapter_input = _preprocess_adapter_image(image, height, width).to(device)
 
-        original_size = (height, width) #original_size or (height, width)
-        target_size = (height, width) #target_size or (height, width)
+        original_size = (height, width)
+        target_size = (height, width)
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
@@ -795,9 +783,6 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
         do_classifier_free_guidance = guidance_scale > 1.0
 
         # 3. Encode input prompt
-        # text_encoder_lora_scale = (
-        #     cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
-        # )
         (
             prompt_embeds,
             negative_prompt_embeds,
@@ -815,7 +800,6 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
             negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-            # lora_scale=text_encoder_lora_scale,
         )
 
         # 4. Prepare timesteps
@@ -925,7 +909,7 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         else:
             image = latents
-            return StableDiffusionAdapterXLPipelineOutput(images=image)
+            return StableDiffusionXLPipelineOutput(images=image)
 
         image = self.image_processor.postprocess(image, output_type=output_type)
 
@@ -936,4 +920,4 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
         if not return_dict:
             return (image,)
 
-        return StableDiffusionAdapterXLPipelineOutput(images=image)
+        return StableDiffusionXLPipelineOutput(images=image)
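
With this patch, the adapter pipeline no longer defines its own output dataclass and instead returns the shared StableDiffusionXLPipelineOutput, so callers consume it exactly like the other SDXL pipelines. The sketch below is illustrative and not part of the diff: it assumes StableDiffusionXLAdapterPipeline and T2IAdapter are exported from the top-level diffusers namespace, and the adapter checkpoint ID, conditioning image, and prompt are placeholders.

import torch
from PIL import Image

from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput

# Placeholder checkpoint ID: any SDXL-compatible T2I-Adapter weights would go here.
adapter = T2IAdapter.from_pretrained("path/to/sdxl-t2i-adapter", torch_dtype=torch.float16)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# Placeholder conditioning image; in practice this is a sketch, lineart, or depth map
# matching the adapter that was loaded above.
control_image = Image.new("RGB", (1024, 1024))

result = pipe(prompt="a photo of a cat", image=control_image, num_inference_steps=30)
# Same output type as StableDiffusionXLPipeline after this change.
assert isinstance(result, StableDiffusionXLPipelineOutput)
image = result.images[0]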