Update TensorRT txt2img and inpaint community pipelines (#9037)
* Update TensorRT txt2img and inpaint community pipelines

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* update tensorrt install instructions

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

---------

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
@@ -1487,17 +1487,16 @@ NOTE: The ONNX conversions and TensorRT engine build may take up to 30 minutes.

 ```python
 import torch
 from diffusers import DDIMScheduler
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
+from diffusers.pipelines import DiffusionPipeline

 # Use the DDIMScheduler scheduler here instead
-scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1",
-                                          subfolder="scheduler")
+scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="scheduler")

-pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1",
-    custom_pipeline="stable_diffusion_tensorrt_txt2img",
-    variant='fp16',
-    torch_dtype=torch.float16,
-    scheduler=scheduler,)
+pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1",
+    custom_pipeline="stable_diffusion_tensorrt_txt2img",
+    variant='fp16',
+    torch_dtype=torch.float16,
+    scheduler=scheduler,)

 # re-use cached folder to save ONNX models and TensorRT Engines
 pipe.set_cached_folder("stabilityai/stable-diffusion-2-1", variant='fp16',)
```
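Note: after the loader swap to `DiffusionPipeline` with `custom_pipeline="stable_diffusion_tensorrt_txt2img"`, the call pattern is unchanged. A minimal continuation of the README snippet, assuming a CUDA GPU (the prompt and filename below are illustrative, not part of the diff):

```python
# Minimal sketch: drive the TensorRT-accelerated pipeline end to end.
pipe = pipe.to("cuda")

prompt = "a beautiful photograph of Mt. Fuji during cherry blossom"
image = pipe(prompt).images[0]
image.save("mt_fuji.png")
```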
@@ -2231,12 +2230,12 @@ from io import BytesIO
 from PIL import Image
 import torch
 from diffusers import PNDMScheduler
-from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
+from diffusers.pipelines import DiffusionPipeline

 # Use the PNDMScheduler scheduler here instead
 scheduler = PNDMScheduler.from_pretrained("stabilityai/stable-diffusion-2-inpainting", subfolder="scheduler")

-pipe = StableDiffusionInpaintPipeline.from_pretrained("stabilityai/stable-diffusion-2-inpainting",
+pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-inpainting",
     custom_pipeline="stable_diffusion_tensorrt_inpaint",
     variant='fp16',
     torch_dtype=torch.float16,
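The hunk context (`from io import BytesIO`, `from PIL import Image`) suggests how the inpaint example continues. A hedged sketch of the invocation, using the example image/mask URLs commonly used in the diffusers docs (assumptions, not part of this diff):

```python
# Sketch: fetch an example image/mask pair and run the TensorRT inpaint pipeline.
import requests

def download_image(url):
    response = requests.get(url)
    return Image.open(BytesIO(response.content)).convert("RGB")

img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"  # example URL
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"  # example URL

init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))

pipe = pipe.to("cuda")
image = pipe(prompt="a mecha robot sitting on a bench", image=init_image, mask_image=mask_image).images[0]
```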
@@ -60,7 +60,7 @@ from diffusers.utils import logging
 """
 Installation instructions
 python3 -m pip install --upgrade transformers diffusers>=0.16.0
-python3 -m pip install --upgrade tensorrt-cu12==10.2.0
+python3 -m pip install --upgrade tensorrt~=10.2.0
 python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
 python3 -m pip install onnxruntime
 """
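`~=10.2.0` is a compatible-release pin (any 10.2.x, but not 10.3): looser than the exact `tensorrt-cu12==10.2.0` pin it replaces, while still excluding series the pipelines have not been ported to. A defensive runtime check along these lines (my own sketch, not from the diff) can catch a stale install early:

```python
# Sketch: verify the installed TensorRT is in the 10.2.x series these
# pipelines target (assumption: 10.2.x is the supported range).
import tensorrt as trt
from packaging import version

installed = version.parse(trt.__version__)
if not (version.parse("10.2.0") <= installed < version.parse("10.3.0")):
    raise RuntimeError(f"Expected TensorRT 10.2.x, found {trt.__version__}")
```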
@@ -659,7 +659,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
     r"""
     Pipeline for image-to-image generation using TensorRT accelerated Stable Diffusion.

-    This model inherits from [`StableDiffusionImg2ImgPipeline`]. Check the superclass documentation for the generic methods the
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

     Args:
@@ -18,8 +18,7 @@
 import gc
 import os
 from collections import OrderedDict
-from copy import copy
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import onnx
@@ -27,9 +26,11 @@ import onnx_graphsurgeon as gs
 import PIL.Image
 import tensorrt as trt
 import torch
+from cuda import cudart
 from huggingface_hub import snapshot_download
+from huggingface_hub.utils import validate_hf_hub_args
 from onnx import shape_inference
 from packaging import version
 from polygraphy import cuda
 from polygraphy.backend.common import bytes_from_path
 from polygraphy.backend.onnx.loader import fold_constants
@@ -41,24 +42,29 @@ from polygraphy.backend.trt import (
     network_from_onnx_path,
     save_engine,
 )
-from polygraphy.backend.trt import util as trt_util
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

+from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import FrozenDict, deprecate
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import (
-    StableDiffusionInpaintPipeline,
     StableDiffusionPipelineOutput,
     StableDiffusionSafetyChecker,
 )
-from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import (
+    prepare_mask_and_masked_image,
+    retrieve_latents,
+)
 from diffusers.schedulers import DDIMScheduler
 from diffusers.utils import logging
+from diffusers.utils.torch_utils import randn_tensor


 """
 Installation instructions
 python3 -m pip install --upgrade transformers diffusers>=0.16.0
-python3 -m pip install --upgrade tensorrt>=8.6.1
+python3 -m pip install --upgrade tensorrt~=10.2.0
 python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
 python3 -m pip install onnxruntime
 """
@@ -88,10 +94,6 @@ else:
 torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}


-def device_view(t):
-    return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype])
-
-
 def preprocess_image(image):
     """
     image: torch.Tensor
@@ -125,10 +127,8 @@ class Engine:
         onnx_path,
         fp16,
         input_profile=None,
-        enable_preview=False,
         enable_all_tactics=False,
         timing_cache=None,
-        workspace_size=0,
     ):
         logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
         p = Profile()
@@ -137,20 +137,13 @@ class Engine:
                 assert len(dims) == 3
                 p.add(name, min=dims[0], opt=dims[1], max=dims[2])

-        config_kwargs = {}
-
-        config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
-        if enable_preview:
-            # Faster dynamic shapes made optional since it increases engine build time.
-            config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805)
-        if workspace_size > 0:
-            config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
+        extra_build_args = {}
         if not enable_all_tactics:
-            config_kwargs["tactic_sources"] = []
+            extra_build_args["tactic_sources"] = []

         engine = engine_from_network(
             network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
-            config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs),
+            config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **extra_build_args),
             save_timing_cache=timing_cache,
         )
         save_engine(engine, path=self.engine_path)
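Context for this simplification (my reading, not stated in the diff): both `PreviewFeature` flags used here became default behavior and no longer exist in TensorRT 10, so only the tactic-source override survives. If a workspace cap is still wanted, Polygraphy's `CreateConfig` accepts the same `memory_pool_limits` mapping the removed code used; a hedged sketch:

```python
# Sketch: capping builder workspace memory under TensorRT 10 / Polygraphy >= 0.47,
# now that the explicit workspace_size plumbing is gone.
import tensorrt as trt
from polygraphy.backend.trt import CreateConfig

config = CreateConfig(
    fp16=True,
    memory_pool_limits={trt.MemoryPoolType.WORKSPACE: 8 << 30},  # 8 GiB cap
)
```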
@@ -163,28 +156,24 @@ class Engine:
         self.context = self.engine.create_execution_context()

     def allocate_buffers(self, shape_dict=None, device="cuda"):
-        for idx in range(trt_util.get_bindings_per_profile(self.engine)):
-            binding = self.engine[idx]
-            if shape_dict and binding in shape_dict:
-                shape = shape_dict[binding]
+        for binding in range(self.engine.num_io_tensors):
+            name = self.engine.get_tensor_name(binding)
+            if shape_dict and name in shape_dict:
+                shape = shape_dict[name]
             else:
-                shape = self.engine.get_binding_shape(binding)
-            dtype = trt.nptype(self.engine.get_binding_dtype(binding))
-            if self.engine.binding_is_input(binding):
-                self.context.set_binding_shape(idx, shape)
+                shape = self.engine.get_tensor_shape(name)
+            dtype = trt.nptype(self.engine.get_tensor_dtype(name))
+            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
+                self.context.set_input_shape(name, shape)
             tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
-            self.tensors[binding] = tensor
-            self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype)
+            self.tensors[name] = tensor

     def infer(self, feed_dict, stream):
-        start_binding, end_binding = trt_util.get_active_profile_bindings(self.context)
-        # shallow copy of ordered dict
-        device_buffers = copy(self.buffers)
         for name, buf in feed_dict.items():
-            assert isinstance(buf, cuda.DeviceView)
-            device_buffers[name] = buf
-        bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
-        noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr)
+            self.tensors[name].copy_(buf)
+        for name, tensor in self.tensors.items():
+            self.context.set_tensor_address(name, tensor.data_ptr())
+        noerror = self.context.execute_async_v3(stream)
         if not noerror:
             raise ValueError("ERROR: inference failed.")
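This is the heart of the TensorRT 10 migration: the removed index-based binding API (`get_binding_shape`, `binding_is_input`, `execute_async_v2` with a pointer list) is replaced by the name-based tensor API, and the Polygraphy `DeviceView` wrappers disappear because `set_tensor_address` takes a torch tensor's device pointer directly. A condensed, self-contained sketch of the new calling pattern (illustrative; assumes a deserialized `trt.ICudaEngine`, and elides dtype mapping for brevity):

```python
# Sketch of TensorRT 10 name-based I/O with torch CUDA tensors.
import tensorrt as trt
import torch
from cuda import cudart

def run(engine, inputs):
    context = engine.create_execution_context()
    stream = cudart.cudaStreamCreate()[1]  # cudart returns (error_code, handle)
    tensors = {}
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            context.set_input_shape(name, inputs[name].shape)
            tensors[name] = inputs[name].contiguous().cuda()
        else:
            shape = tuple(context.get_tensor_shape(name))
            # Simplification: real code maps trt dtypes as in the diff above.
            tensors[name] = torch.empty(shape, device="cuda")
        context.set_tensor_address(name, tensors[name].data_ptr())
    if not context.execute_async_v3(stream):
        raise RuntimeError("TensorRT inference failed")
    cudart.cudaStreamSynchronize(stream)
    return {n: t for n, t in tensors.items()
            if engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT}
```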
@@ -325,10 +314,8 @@ def build_engines(
     force_engine_rebuild=False,
     static_batch=False,
     static_shape=True,
-    enable_preview=False,
     enable_all_tactics=False,
     timing_cache=None,
-    max_workspace_size=0,
 ):
     built_engines = {}
     if not os.path.isdir(onnx_dir):
@@ -393,9 +380,7 @@ def build_engines(
                 static_batch=static_batch,
                 static_shape=static_shape,
             ),
-            enable_preview=enable_preview,
             timing_cache=timing_cache,
-            workspace_size=max_workspace_size,
         )
         built_engines[model_name] = engine
@@ -674,11 +659,11 @@ def make_VAEEncoder(model, device, max_batch_size, embedding_dim, inpaint=False)
     return VAEEncoder(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)


-class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
+class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
     r"""
     Pipeline for inpainting using TensorRT accelerated Stable Diffusion.

-    This model inherits from [`StableDiffusionInpaintPipeline`]. Check the superclass documentation for the generic methods the
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

     Args:
@@ -702,6 +687,8 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """

+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+
     def __init__(
         self,
         vae: AutoencoderKL,
@@ -722,24 +709,86 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
         onnx_dir: str = "onnx",
         # TensorRT engine build parameters
         engine_dir: str = "engine",
-        build_preview_features: bool = True,
         force_engine_rebuild: bool = False,
         timing_cache: str = "timing_cache",
     ):
-        super().__init__(
-            vae,
-            text_encoder,
-            tokenizer,
-            unet,
-            scheduler,
+        super().__init__()
+
+        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                " file"
+            )
+            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["steps_offset"] = 1
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if safety_checker is None and requires_safety_checker:
+            logger.warning(
+                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+            )
+
+        if safety_checker is not None and feature_extractor is None:
+            raise ValueError(
+                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+            )
+
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
             requires_safety_checker=requires_safety_checker,
         )

-        self.vae.forward = self.vae.decode
-
         self.stages = stages
         self.image_height, self.image_width = image_height, image_width
         self.inpaint = True
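The structural change here is worth spelling out: instead of inheriting `StableDiffusionInpaintPipeline` and delegating to its constructor, the pipeline now derives from the generic `DiffusionPipeline` and wires up its own components, inlining the scheduler and UNet config shims the old superclass used to apply. The minimal shape of that pattern (my own reduction, not code from the diff):

```python
# Minimal sketch of the DiffusionPipeline subclassing pattern used above:
# components go through register_modules(), plain config values through
# register_to_config(), so save_pretrained()/from_pretrained() round-trip them.
from diffusers import DiffusionPipeline

class MyTensorRTPipeline(DiffusionPipeline):
    def __init__(self, vae, text_encoder, tokenizer, unet, scheduler,
                 requires_safety_checker: bool = True):
        super().__init__()  # no components passed up; we register them ourselves
        self.register_modules(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
            unet=unet, scheduler=scheduler,
        )
        self.register_to_config(requires_safety_checker=requires_safety_checker)
```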
@@ -750,7 +799,6 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
         self.timing_cache = timing_cache
         self.build_static_batch = False
         self.build_dynamic_shape = False
-        self.build_preview_features = build_preview_features

         self.max_batch_size = max_batch_size
         # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation.
@@ -761,6 +809,11 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
         self.models = {}  # loaded in __loadModels()
         self.engine = {}  # loaded in build_engines()

+        self.vae.forward = self.vae.decode
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
+
     def __loadModels(self):
         # Load pipeline models
         self.embedding_dim = self.text_encoder.config.hidden_size
@@ -779,6 +832,112 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
         if "vae_encoder" in self.stages:
             self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args)

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+        if isinstance(generator, list):
+            image_latents = [
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+        image_latents = self.vae.config.scaling_factor * image_latents
+
+        return image_latents
+
+    def prepare_latents(
+        self,
+        batch_size,
+        num_channels_latents,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+        image=None,
+        timestep=None,
+        is_strength_max=True,
+        return_noise=False,
+        return_image_latents=False,
+    ):
+        shape = (
+            batch_size,
+            num_channels_latents,
+            int(height) // self.vae_scale_factor,
+            int(width) // self.vae_scale_factor,
+        )
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if (image is None or timestep is None) and not is_strength_max:
+            raise ValueError(
+                "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+                "However, either the image or the noise timestep has not been provided."
+            )
+
+        if return_image_latents or (latents is None and not is_strength_max):
+            image = image.to(device=device, dtype=dtype)
+
+            if image.shape[1] == 4:
+                image_latents = image
+            else:
+                image_latents = self._encode_vae_image(image=image, generator=generator)
+            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+        if latents is None:
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            # if strength is 1. then initialise the latents to noise, else initial to image + noise
+            latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+            # if pure noise then scale the initial latents by the Scheduler's init sigma
+            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+        else:
+            noise = latents.to(device)
+            latents = noise * self.scheduler.init_noise_sigma
+
+        outputs = (latents,)
+
+        if return_noise:
+            outputs += (noise,)
+
+        if return_image_latents:
+            outputs += (image_latents,)
+
+        return outputs
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+    def run_safety_checker(
+        self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
+    ) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
+        r"""
+        Runs the safety checker on the given image.
+        Args:
+            image (Union[torch.Tensor, PIL.Image.Image]): The input image to be checked.
+            device (torch.device): The device to run the safety checker on.
+            dtype (torch.dtype): The data type of the input image.
+        Returns:
+            (image, has_nsfw_concept) Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: A tuple containing the processed image and
+            a boolean indicating whether the image has a NSFW (Not Safe for Work) concept.
+        """
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
+        return image, has_nsfw_concept
+
     @classmethod
     @validate_hf_hub_args
     def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
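The `is_strength_max` branching above encodes the usual inpaint trade-off: at `strength == 1.0` the initial latents are pure noise scaled by the scheduler's `init_noise_sigma`, while at lower strength they are the encoded image with noise added at the chosen timestep. A toy illustration of the two branches (hedged sketch; the shapes and timestep are arbitrary):

```python
# Toy illustration of the two prepare_latents branches (not from the diff).
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-inpainting", subfolder="scheduler")
noise = torch.randn(1, 4, 64, 64)
image_latents = torch.randn(1, 4, 64, 64)  # stand-in for VAE-encoded image

# strength == 1.0: start from pure noise, scaled to the scheduler's sigma
latents_full = noise * scheduler.init_noise_sigma

# strength < 1.0: start from the image, noised to an intermediate timestep
timestep = torch.tensor([500])
latents_partial = scheduler.add_noise(image_latents, noise, timestep)
```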
@@ -826,7 +985,6 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
             force_engine_rebuild=self.force_engine_rebuild,
             static_batch=self.build_static_batch,
             static_shape=not self.build_dynamic_shape,
-            enable_preview=self.build_preview_features,
             timing_cache=self.timing_cache,
         )
@@ -850,9 +1008,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
         return tuple(init_images)

     def __encode_image(self, init_image):
-        init_latents = runEngine(self.engine["vae_encoder"], {"images": device_view(init_image)}, self.stream)[
-            "latent"
-        ]
+        init_latents = runEngine(self.engine["vae_encoder"], {"images": init_image}, self.stream)["latent"]
         init_latents = 0.18215 * init_latents
         return init_latents
@@ -881,9 +1037,8 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
             .to(self.torch_device)
         )

-        text_input_ids_inp = device_view(text_input_ids)
         # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
-        text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[
+        text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids}, self.stream)[
            "text_embeddings"
         ].clone()
@@ -899,8 +1054,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
             .input_ids.type(torch.int32)
             .to(self.torch_device)
         )
-        uncond_input_ids_inp = device_view(uncond_input_ids)
-        uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[
+        uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids}, self.stream)[
             "text_embeddings"
         ]
@@ -924,18 +1078,15 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
             # Predict the noise residual
             timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep

-            sample_inp = device_view(latent_model_input)
-            timestep_inp = device_view(timestep_float)
-            embeddings_inp = device_view(text_embeddings)
             noise_pred = runEngine(
                 self.engine["unet"],
-                {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp},
+                {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings},
                 self.stream,
             )["latent"]

             # Perform guidance
             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+            noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)

             latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample
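The denoising loop applies standard classifier-free guidance: the batch stacks the unconditional and text-conditioned predictions, and the final residual is `uncond + g * (text - uncond)`, so `g = 1` disables guidance and larger `g` pushes the sample toward the prompt. The only behavioral changes in this hunk are bookkeeping: inputs are fed as torch tensors rather than `device_view` wrappers, and the scale is read from the `_guidance_scale` attribute. The guidance step in isolation (toy tensors, not from the diff):

```python
# Classifier-free guidance on a stand-in batch of noise predictions.
import torch

guidance_scale = 7.5
noise_pred = torch.randn(2, 4, 64, 64)  # [uncond, text] stacked on the batch dim
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```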
@@ -943,12 +1094,12 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
         return latents

     def __decode_latent(self, latents):
-        images = runEngine(self.engine["vae"], {"latent": device_view(latents)}, self.stream)["images"]
+        images = runEngine(self.engine["vae"], {"latent": latents}, self.stream)["images"]
         images = (images / 2 + 0.5).clamp(0, 1)
         return images.cpu().permute(0, 2, 3, 1).float().numpy()

     def __loadResources(self, image_height, image_width, batch_size):
-        self.stream = cuda.Stream()
+        self.stream = cudart.cudaStreamCreate()[1]

         # Allocate buffers for TensorRT engine bindings
         for model_name, obj in self.models.items():
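One easy-to-miss detail: the `cuda-python` `cudart` bindings return a `(cudaError_t, result)` tuple rather than raising, which is why the stream handle is taken with `[1]`. A stricter variant (my sketch, not from the diff) would check the error code instead of discarding it:

```python
# Sketch: error-checked stream creation with cuda-python's cudart bindings.
from cuda import cudart

err, stream = cudart.cudaStreamCreate()
if err != cudart.cudaError_t.cudaSuccess:
    raise RuntimeError(f"cudaStreamCreate failed: {err}")
```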
@@ -1112,5 +1263,6 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
         # VAE decode latent
         images = self.__decode_latent(latents)

+        images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype)
         images = self.numpy_to_pil(images)
-        return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=None)
+        return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
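With the safety checker actually wired into the output (previously the flag was hardcoded to `None`), callers can inspect it. An illustrative consumption pattern (my sketch):

```python
# Sketch: consuming the now-populated nsfw flag in caller code.
output = pipe(prompt="a photo of an astronaut riding a horse on mars")
flags = output.nsfw_content_detected or [False] * len(output.images)
for i, (img, flagged) in enumerate(zip(output.images, flags)):
    if not flagged:
        img.save(f"out_{i}.png")
```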
@@ -18,17 +18,19 @@
 import gc
 import os
 from collections import OrderedDict
 from copy import copy
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import onnx
 import onnx_graphsurgeon as gs
 import PIL.Image
 import tensorrt as trt
 import torch
+from cuda import cudart
 from huggingface_hub import snapshot_download
+from huggingface_hub.utils import validate_hf_hub_args
 from onnx import shape_inference
 from packaging import version
 from polygraphy import cuda
 from polygraphy.backend.common import bytes_from_path
 from polygraphy.backend.onnx.loader import fold_constants
@@ -40,23 +42,25 @@ from polygraphy.backend.trt import (
     network_from_onnx_path,
     save_engine,
 )
-from polygraphy.backend.trt import util as trt_util
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

+from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import FrozenDict, deprecate
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import (
-    StableDiffusionPipeline,
     StableDiffusionPipelineOutput,
     StableDiffusionSafetyChecker,
 )
 from diffusers.schedulers import DDIMScheduler
 from diffusers.utils import logging
+from diffusers.utils.torch_utils import randn_tensor


 """
 Installation instructions
 python3 -m pip install --upgrade transformers diffusers>=0.16.0
-python3 -m pip install --upgrade tensorrt>=8.6.1
+python3 -m pip install --upgrade tensorrt~=10.2.0
 python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
 python3 -m pip install onnxruntime
 """
@@ -86,10 +90,6 @@ else:
 torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}


-def device_view(t):
-    return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype])
-
-
 class Engine:
     def __init__(self, engine_path):
         self.engine_path = engine_path
@@ -110,10 +110,8 @@ class Engine:
         onnx_path,
         fp16,
         input_profile=None,
-        enable_preview=False,
         enable_all_tactics=False,
         timing_cache=None,
-        workspace_size=0,
     ):
         logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
         p = Profile()
@@ -122,20 +120,13 @@ class Engine:
                 assert len(dims) == 3
                 p.add(name, min=dims[0], opt=dims[1], max=dims[2])

-        config_kwargs = {}
-
-        config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
-        if enable_preview:
-            # Faster dynamic shapes made optional since it increases engine build time.
-            config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805)
-        if workspace_size > 0:
-            config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
+        extra_build_args = {}
         if not enable_all_tactics:
-            config_kwargs["tactic_sources"] = []
+            extra_build_args["tactic_sources"] = []

         engine = engine_from_network(
             network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
-            config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs),
+            config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **extra_build_args),
             save_timing_cache=timing_cache,
         )
         save_engine(engine, path=self.engine_path)
@@ -148,28 +139,24 @@ class Engine:
         self.context = self.engine.create_execution_context()

     def allocate_buffers(self, shape_dict=None, device="cuda"):
-        for idx in range(trt_util.get_bindings_per_profile(self.engine)):
-            binding = self.engine[idx]
-            if shape_dict and binding in shape_dict:
-                shape = shape_dict[binding]
+        for binding in range(self.engine.num_io_tensors):
+            name = self.engine.get_tensor_name(binding)
+            if shape_dict and name in shape_dict:
+                shape = shape_dict[name]
             else:
-                shape = self.engine.get_binding_shape(binding)
-            dtype = trt.nptype(self.engine.get_binding_dtype(binding))
-            if self.engine.binding_is_input(binding):
-                self.context.set_binding_shape(idx, shape)
+                shape = self.engine.get_tensor_shape(name)
+            dtype = trt.nptype(self.engine.get_tensor_dtype(name))
+            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
+                self.context.set_input_shape(name, shape)
             tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
-            self.tensors[binding] = tensor
-            self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype)
+            self.tensors[name] = tensor

     def infer(self, feed_dict, stream):
-        start_binding, end_binding = trt_util.get_active_profile_bindings(self.context)
-        # shallow copy of ordered dict
-        device_buffers = copy(self.buffers)
         for name, buf in feed_dict.items():
-            assert isinstance(buf, cuda.DeviceView)
-            device_buffers[name] = buf
-        bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
-        noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr)
+            self.tensors[name].copy_(buf)
+        for name, tensor in self.tensors.items():
+            self.context.set_tensor_address(name, tensor.data_ptr())
+        noerror = self.context.execute_async_v3(stream)
         if not noerror:
             raise ValueError("ERROR: inference failed.")
@@ -310,10 +297,8 @@ def build_engines(
     force_engine_rebuild=False,
     static_batch=False,
     static_shape=True,
-    enable_preview=False,
    enable_all_tactics=False,
     timing_cache=None,
-    max_workspace_size=0,
 ):
     built_engines = {}
     if not os.path.isdir(onnx_dir):
@@ -378,9 +363,7 @@ def build_engines(
                 static_batch=static_batch,
                 static_shape=static_shape,
             ),
-            enable_preview=enable_preview,
             timing_cache=timing_cache,
-            workspace_size=max_workspace_size,
         )
         built_engines[model_name] = engine
@@ -588,11 +571,11 @@ def make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False):
     return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)


-class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
+class TensorRTStableDiffusionPipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using TensorRT accelerated Stable Diffusion.

-    This model inherits from [`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

     Args:
@@ -616,6 +599,8 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """

+    _optional_components = ["safety_checker", "feature_extractor"]
+
     def __init__(
         self,
         vae: AutoencoderKL,
@@ -632,28 +617,90 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
         image_width: int = 768,
         max_batch_size: int = 16,
         # ONNX export parameters
-        onnx_opset: int = 17,
+        onnx_opset: int = 18,
         onnx_dir: str = "onnx",
         # TensorRT engine build parameters
         engine_dir: str = "engine",
-        build_preview_features: bool = True,
         force_engine_rebuild: bool = False,
         timing_cache: str = "timing_cache",
     ):
-        super().__init__(
-            vae,
-            text_encoder,
-            tokenizer,
-            unet,
-            scheduler,
+        super().__init__()
+
+        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                " file"
+            )
+            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["steps_offset"] = 1
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if safety_checker is None and requires_safety_checker:
+            logger.warning(
+                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+            )
+
+        if safety_checker is not None and feature_extractor is None:
+            raise ValueError(
+                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+            )
+
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
             requires_safety_checker=requires_safety_checker,
         )

-        self.vae.forward = self.vae.decode
-
         self.stages = stages
         self.image_height, self.image_width = image_height, image_width
         self.inpaint = False
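Beyond mirroring the inpaint changes, the txt2img constructor also bumps the default ONNX export opset from 17 to 18, which newer torch exporters and the TensorRT 10 parser support. For anyone exporting a model by hand, the corresponding knob is `opset_version` (hedged sketch; the dummy module and shapes are illustrative):

```python
# Sketch: exporting a small module at opset 18 (dummy module and shapes).
import torch

class TinyNet(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.silu(x)

torch.onnx.export(
    TinyNet(),
    torch.randn(1, 4, 64, 64),
    "tinynet.onnx",
    opset_version=18,
)
```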
@@ -664,7 +711,6 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
         self.timing_cache = timing_cache
         self.build_static_batch = False
         self.build_dynamic_shape = False
-        self.build_preview_features = build_preview_features

         self.max_batch_size = max_batch_size
         # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation.
@@ -675,6 +721,11 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
         self.models = {}  # loaded in __loadModels()
         self.engine = {}  # loaded in build_engines()

+        self.vae.forward = self.vae.decode
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
+
     def __loadModels(self):
         # Load pipeline models
         self.embedding_dim = self.text_encoder.config.hidden_size
@@ -691,6 +742,75 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
         if "vae" in self.stages:
             self.models["vae"] = make_VAE(self.vae, **models_args)

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+    def prepare_latents(
+        self,
+        batch_size: int,
+        num_channels_latents: int,
+        height: int,
+        width: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        generator: Union[torch.Generator, List[torch.Generator]],
+        latents: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        r"""
+        Prepare the latent vectors for diffusion.
+        Args:
+            batch_size (int): The number of samples in the batch.
+            num_channels_latents (int): The number of channels in the latent vectors.
+            height (int): The height of the latent vectors.
+            width (int): The width of the latent vectors.
+            dtype (torch.dtype): The data type of the latent vectors.
+            device (torch.device): The device to place the latent vectors on.
+            generator (Union[torch.Generator, List[torch.Generator]]): The generator(s) to use for random number generation.
+            latents (Optional[torch.Tensor]): The pre-existing latent vectors. If None, new latent vectors will be generated.
+        Returns:
+            torch.Tensor: The prepared latent vectors.
+        """
+        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+    def run_safety_checker(
+        self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
+    ) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
+        r"""
+        Runs the safety checker on the given image.
+        Args:
+            image (Union[torch.Tensor, PIL.Image.Image]): The input image to be checked.
+            device (torch.device): The device to run the safety checker on.
+            dtype (torch.dtype): The data type of the input image.
+        Returns:
+            (image, has_nsfw_concept) Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: A tuple containing the processed image and
+            a boolean indicating whether the image has a NSFW (Not Safe for Work) concept.
+        """
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
+        return image, has_nsfw_concept
+
     @classmethod
     @validate_hf_hub_args
     def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
@@ -738,7 +858,6 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
             force_engine_rebuild=self.force_engine_rebuild,
             static_batch=self.build_static_batch,
             static_shape=not self.build_dynamic_shape,
-            enable_preview=self.build_preview_features,
             timing_cache=self.timing_cache,
         )
@@ -769,9 +888,8 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
             .to(self.torch_device)
         )

-        text_input_ids_inp = device_view(text_input_ids)
         # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
-        text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[
+        text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids}, self.stream)[
             "text_embeddings"
         ].clone()
@@ -787,8 +905,7 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
             .input_ids.type(torch.int32)
             .to(self.torch_device)
         )
-        uncond_input_ids_inp = device_view(uncond_input_ids)
-        uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[
+        uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids}, self.stream)[
             "text_embeddings"
         ]
@@ -812,18 +929,15 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
             # Predict the noise residual
             timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep

-            sample_inp = device_view(latent_model_input)
-            timestep_inp = device_view(timestep_float)
-            embeddings_inp = device_view(text_embeddings)
             noise_pred = runEngine(
                 self.engine["unet"],
-                {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp},
+                {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings},
                 self.stream,
             )["latent"]

             # Perform guidance
             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+            noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)

             latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample
@@ -831,12 +945,12 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
         return latents

     def __decode_latent(self, latents):
-        images = runEngine(self.engine["vae"], {"latent": device_view(latents)}, self.stream)["images"]
+        images = runEngine(self.engine["vae"], {"latent": latents}, self.stream)["images"]
         images = (images / 2 + 0.5).clamp(0, 1)
         return images.cpu().permute(0, 2, 3, 1).float().numpy()

     def __loadResources(self, image_height, image_width, batch_size):
-        self.stream = cuda.Stream()
+        self.stream = cudart.cudaStreamCreate()[1]

         # Allocate buffers for TensorRT engine bindings
         for model_name, obj in self.models.items():