1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Update TensorRT txt2img and inpaint community pipelines (#9037)

* Update TensorRT txt2img and inpaint community pipelines

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* update tensorrt install instructions

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

---------

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
asfiyab-nvidia
2024-08-04 03:30:40 -07:00
committed by GitHub
parent c370b90ff1
commit 3dc10a535f
4 changed files with 417 additions and 152 deletions

View File

@@ -1487,17 +1487,16 @@ NOTE: The ONNX conversions and TensorRT engine build may take up to 30 minutes.
```python
import torch
from diffusers import DDIMScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines import DiffusionPipeline
# Use the DDIMScheduler scheduler here instead
scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1",
subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1",
custom_pipeline="stable_diffusion_tensorrt_txt2img",
variant='fp16',
torch_dtype=torch.float16,
scheduler=scheduler,)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1",
custom_pipeline="stable_diffusion_tensorrt_txt2img",
variant='fp16',
torch_dtype=torch.float16,
scheduler=scheduler,)
# re-use cached folder to save ONNX models and TensorRT Engines
pipe.set_cached_folder("stabilityai/stable-diffusion-2-1", variant='fp16',)
@@ -2231,12 +2230,12 @@ from io import BytesIO
from PIL import Image
import torch
from diffusers import PNDMScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
from diffusers.pipelines import DiffusionPipeline
# Use the PNDMScheduler scheduler here instead
scheduler = PNDMScheduler.from_pretrained("stabilityai/stable-diffusion-2-inpainting", subfolder="scheduler")
pipe = StableDiffusionInpaintPipeline.from_pretrained("stabilityai/stable-diffusion-2-inpainting",
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-inpainting",
custom_pipeline="stable_diffusion_tensorrt_inpaint",
variant='fp16',
torch_dtype=torch.float16,

View File

@@ -60,7 +60,7 @@ from diffusers.utils import logging
"""
Installation instructions
python3 -m pip install --upgrade transformers diffusers>=0.16.0
python3 -m pip install --upgrade tensorrt-cu12==10.2.0
python3 -m pip install --upgrade tensorrt~=10.2.0
python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
python3 -m pip install onnxruntime
"""
@@ -659,7 +659,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
r"""
Pipeline for image-to-image generation using TensorRT accelerated Stable Diffusion.
This model inherits from [`StableDiffusionImg2ImgPipeline`]. Check the superclass documentation for the generic methods the
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:

View File

@@ -18,8 +18,7 @@
import gc
import os
from collections import OrderedDict
from copy import copy
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union
import numpy as np
import onnx
@@ -27,9 +26,11 @@ import onnx_graphsurgeon as gs
import PIL.Image
import tensorrt as trt
import torch
from cuda import cudart
from huggingface_hub import snapshot_download
from huggingface_hub.utils import validate_hf_hub_args
from onnx import shape_inference
from packaging import version
from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.onnx.loader import fold_constants
@@ -41,24 +42,29 @@ from polygraphy.backend.trt import (
network_from_onnx_path,
save_engine,
)
from polygraphy.backend.trt import util as trt_util
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict, deprecate
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
StableDiffusionInpaintPipeline,
StableDiffusionPipelineOutput,
StableDiffusionSafetyChecker,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import (
prepare_mask_and_masked_image,
retrieve_latents,
)
from diffusers.schedulers import DDIMScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
"""
Installation instructions
python3 -m pip install --upgrade transformers diffusers>=0.16.0
python3 -m pip install --upgrade tensorrt>=8.6.1
python3 -m pip install --upgrade tensorrt~=10.2.0
python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
python3 -m pip install onnxruntime
"""
@@ -88,10 +94,6 @@ else:
torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
def device_view(t):
return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype])
def preprocess_image(image):
"""
image: torch.Tensor
@@ -125,10 +127,8 @@ class Engine:
onnx_path,
fp16,
input_profile=None,
enable_preview=False,
enable_all_tactics=False,
timing_cache=None,
workspace_size=0,
):
logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
p = Profile()
@@ -137,20 +137,13 @@ class Engine:
assert len(dims) == 3
p.add(name, min=dims[0], opt=dims[1], max=dims[2])
config_kwargs = {}
config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
if enable_preview:
# Faster dynamic shapes made optional since it increases engine build time.
config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805)
if workspace_size > 0:
config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
extra_build_args = {}
if not enable_all_tactics:
config_kwargs["tactic_sources"] = []
extra_build_args["tactic_sources"] = []
engine = engine_from_network(
network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs),
config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **extra_build_args),
save_timing_cache=timing_cache,
)
save_engine(engine, path=self.engine_path)
@@ -163,28 +156,24 @@ class Engine:
self.context = self.engine.create_execution_context()
def allocate_buffers(self, shape_dict=None, device="cuda"):
for idx in range(trt_util.get_bindings_per_profile(self.engine)):
binding = self.engine[idx]
if shape_dict and binding in shape_dict:
shape = shape_dict[binding]
for binding in range(self.engine.num_io_tensors):
name = self.engine.get_tensor_name(binding)
if shape_dict and name in shape_dict:
shape = shape_dict[name]
else:
shape = self.engine.get_binding_shape(binding)
dtype = trt.nptype(self.engine.get_binding_dtype(binding))
if self.engine.binding_is_input(binding):
self.context.set_binding_shape(idx, shape)
shape = self.engine.get_tensor_shape(name)
dtype = trt.nptype(self.engine.get_tensor_dtype(name))
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
self.context.set_input_shape(name, shape)
tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
self.tensors[binding] = tensor
self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype)
self.tensors[name] = tensor
def infer(self, feed_dict, stream):
start_binding, end_binding = trt_util.get_active_profile_bindings(self.context)
# shallow copy of ordered dict
device_buffers = copy(self.buffers)
for name, buf in feed_dict.items():
assert isinstance(buf, cuda.DeviceView)
device_buffers[name] = buf
bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr)
self.tensors[name].copy_(buf)
for name, tensor in self.tensors.items():
self.context.set_tensor_address(name, tensor.data_ptr())
noerror = self.context.execute_async_v3(stream)
if not noerror:
raise ValueError("ERROR: inference failed.")
@@ -325,10 +314,8 @@ def build_engines(
force_engine_rebuild=False,
static_batch=False,
static_shape=True,
enable_preview=False,
enable_all_tactics=False,
timing_cache=None,
max_workspace_size=0,
):
built_engines = {}
if not os.path.isdir(onnx_dir):
@@ -393,9 +380,7 @@ def build_engines(
static_batch=static_batch,
static_shape=static_shape,
),
enable_preview=enable_preview,
timing_cache=timing_cache,
workspace_size=max_workspace_size,
)
built_engines[model_name] = engine
@@ -674,11 +659,11 @@ def make_VAEEncoder(model, device, max_batch_size, embedding_dim, inpaint=False)
return VAEEncoder(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
r"""
Pipeline for inpainting using TensorRT accelerated Stable Diffusion.
This model inherits from [`StableDiffusionInpaintPipeline`]. Check the superclass documentation for the generic methods the
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
@@ -702,6 +687,8 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
def __init__(
self,
vae: AutoencoderKL,
@@ -722,24 +709,86 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
onnx_dir: str = "onnx",
# TensorRT engine build parameters
engine_dir: str = "engine",
build_preview_features: bool = True,
force_engine_rebuild: bool = False,
timing_cache: str = "timing_cache",
):
super().__init__(
vae,
text_encoder,
tokenizer,
unet,
scheduler,
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
requires_safety_checker=requires_safety_checker,
)
self.vae.forward = self.vae.decode
self.stages = stages
self.image_height, self.image_width = image_height, image_width
self.inpaint = True
@@ -750,7 +799,6 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
self.timing_cache = timing_cache
self.build_static_batch = False
self.build_dynamic_shape = False
self.build_preview_features = build_preview_features
self.max_batch_size = max_batch_size
# TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation.
@@ -761,6 +809,11 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
self.models = {} # loaded in __loadModels()
self.engine = {} # loaded in build_engines()
self.vae.forward = self.vae.decode
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def __loadModels(self):
# Load pipeline models
self.embedding_dim = self.text_encoder.config.hidden_size
@@ -779,6 +832,112 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
if "vae_encoder" in self.stages:
self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
image_latents = self.vae.config.scaling_factor * image_latents
return image_latents
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
image=None,
timestep=None,
is_strength_max=True,
return_noise=False,
return_image_latents=False,
):
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if (image is None or timestep is None) and not is_strength_max:
raise ValueError(
"Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
"However, either the image or the noise timestep has not been provided."
)
if return_image_latents or (latents is None and not is_strength_max):
image = image.to(device=device, dtype=dtype)
if image.shape[1] == 4:
image_latents = image
else:
image_latents = self._encode_vae_image(image=image, generator=generator)
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
if latents is None:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# if strength is 1. then initialise the latents to noise, else initial to image + noise
latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
# if pure noise then scale the initial latents by the Scheduler's init sigma
latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
else:
noise = latents.to(device)
latents = noise * self.scheduler.init_noise_sigma
outputs = (latents,)
if return_noise:
outputs += (noise,)
if return_image_latents:
outputs += (image_latents,)
return outputs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(
self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
r"""
Runs the safety checker on the given image.
Args:
image (Union[torch.Tensor, PIL.Image.Image]): The input image to be checked.
device (torch.device): The device to run the safety checker on.
dtype (torch.dtype): The data type of the input image.
Returns:
(image, has_nsfw_concept) Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: A tuple containing the processed image and
a boolean indicating whether the image has a NSFW (Not Safe for Work) concept.
"""
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept
@classmethod
@validate_hf_hub_args
def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
@@ -826,7 +985,6 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
force_engine_rebuild=self.force_engine_rebuild,
static_batch=self.build_static_batch,
static_shape=not self.build_dynamic_shape,
enable_preview=self.build_preview_features,
timing_cache=self.timing_cache,
)
@@ -850,9 +1008,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
return tuple(init_images)
def __encode_image(self, init_image):
init_latents = runEngine(self.engine["vae_encoder"], {"images": device_view(init_image)}, self.stream)[
"latent"
]
init_latents = runEngine(self.engine["vae_encoder"], {"images": init_image}, self.stream)["latent"]
init_latents = 0.18215 * init_latents
return init_latents
@@ -881,9 +1037,8 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
.to(self.torch_device)
)
text_input_ids_inp = device_view(text_input_ids)
# NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[
text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids}, self.stream)[
"text_embeddings"
].clone()
@@ -899,8 +1054,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
.input_ids.type(torch.int32)
.to(self.torch_device)
)
uncond_input_ids_inp = device_view(uncond_input_ids)
uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[
uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids}, self.stream)[
"text_embeddings"
]
@@ -924,18 +1078,15 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
# Predict the noise residual
timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep
sample_inp = device_view(latent_model_input)
timestep_inp = device_view(timestep_float)
embeddings_inp = device_view(text_embeddings)
noise_pred = runEngine(
self.engine["unet"],
{"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp},
{"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings},
self.stream,
)["latent"]
# Perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample
@@ -943,12 +1094,12 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
return latents
def __decode_latent(self, latents):
images = runEngine(self.engine["vae"], {"latent": device_view(latents)}, self.stream)["images"]
images = runEngine(self.engine["vae"], {"latent": latents}, self.stream)["images"]
images = (images / 2 + 0.5).clamp(0, 1)
return images.cpu().permute(0, 2, 3, 1).float().numpy()
def __loadResources(self, image_height, image_width, batch_size):
self.stream = cuda.Stream()
self.stream = cudart.cudaStreamCreate()[1]
# Allocate buffers for TensorRT engine bindings
for model_name, obj in self.models.items():
@@ -1112,5 +1263,6 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
# VAE decode latent
images = self.__decode_latent(latents)
images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype)
images = self.numpy_to_pil(images)
return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=None)
return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)

View File

@@ -18,17 +18,19 @@
import gc
import os
from collections import OrderedDict
from copy import copy
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union
import numpy as np
import onnx
import onnx_graphsurgeon as gs
import PIL.Image
import tensorrt as trt
import torch
from cuda import cudart
from huggingface_hub import snapshot_download
from huggingface_hub.utils import validate_hf_hub_args
from onnx import shape_inference
from packaging import version
from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.onnx.loader import fold_constants
@@ -40,23 +42,25 @@ from polygraphy.backend.trt import (
network_from_onnx_path,
save_engine,
)
from polygraphy.backend.trt import util as trt_util
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict, deprecate
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
StableDiffusionPipeline,
StableDiffusionPipelineOutput,
StableDiffusionSafetyChecker,
)
from diffusers.schedulers import DDIMScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
"""
Installation instructions
python3 -m pip install --upgrade transformers diffusers>=0.16.0
python3 -m pip install --upgrade tensorrt>=8.6.1
python3 -m pip install --upgrade tensorrt~=10.2.0
python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
python3 -m pip install onnxruntime
"""
@@ -86,10 +90,6 @@ else:
torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
def device_view(t):
return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype])
class Engine:
def __init__(self, engine_path):
self.engine_path = engine_path
@@ -110,10 +110,8 @@ class Engine:
onnx_path,
fp16,
input_profile=None,
enable_preview=False,
enable_all_tactics=False,
timing_cache=None,
workspace_size=0,
):
logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
p = Profile()
@@ -122,20 +120,13 @@ class Engine:
assert len(dims) == 3
p.add(name, min=dims[0], opt=dims[1], max=dims[2])
config_kwargs = {}
config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
if enable_preview:
# Faster dynamic shapes made optional since it increases engine build time.
config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805)
if workspace_size > 0:
config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
extra_build_args = {}
if not enable_all_tactics:
config_kwargs["tactic_sources"] = []
extra_build_args["tactic_sources"] = []
engine = engine_from_network(
network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs),
config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **extra_build_args),
save_timing_cache=timing_cache,
)
save_engine(engine, path=self.engine_path)
@@ -148,28 +139,24 @@ class Engine:
self.context = self.engine.create_execution_context()
def allocate_buffers(self, shape_dict=None, device="cuda"):
for idx in range(trt_util.get_bindings_per_profile(self.engine)):
binding = self.engine[idx]
if shape_dict and binding in shape_dict:
shape = shape_dict[binding]
for binding in range(self.engine.num_io_tensors):
name = self.engine.get_tensor_name(binding)
if shape_dict and name in shape_dict:
shape = shape_dict[name]
else:
shape = self.engine.get_binding_shape(binding)
dtype = trt.nptype(self.engine.get_binding_dtype(binding))
if self.engine.binding_is_input(binding):
self.context.set_binding_shape(idx, shape)
shape = self.engine.get_tensor_shape(name)
dtype = trt.nptype(self.engine.get_tensor_dtype(name))
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
self.context.set_input_shape(name, shape)
tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
self.tensors[binding] = tensor
self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype)
self.tensors[name] = tensor
def infer(self, feed_dict, stream):
start_binding, end_binding = trt_util.get_active_profile_bindings(self.context)
# shallow copy of ordered dict
device_buffers = copy(self.buffers)
for name, buf in feed_dict.items():
assert isinstance(buf, cuda.DeviceView)
device_buffers[name] = buf
bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr)
self.tensors[name].copy_(buf)
for name, tensor in self.tensors.items():
self.context.set_tensor_address(name, tensor.data_ptr())
noerror = self.context.execute_async_v3(stream)
if not noerror:
raise ValueError("ERROR: inference failed.")
@@ -310,10 +297,8 @@ def build_engines(
force_engine_rebuild=False,
static_batch=False,
static_shape=True,
enable_preview=False,
enable_all_tactics=False,
timing_cache=None,
max_workspace_size=0,
):
built_engines = {}
if not os.path.isdir(onnx_dir):
@@ -378,9 +363,7 @@ def build_engines(
static_batch=static_batch,
static_shape=static_shape,
),
enable_preview=enable_preview,
timing_cache=timing_cache,
workspace_size=max_workspace_size,
)
built_engines[model_name] = engine
@@ -588,11 +571,11 @@ def make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False):
return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
class TensorRTStableDiffusionPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using TensorRT accelerated Stable Diffusion.
This model inherits from [`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods the
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
@@ -616,6 +599,8 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae: AutoencoderKL,
@@ -632,28 +617,90 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
image_width: int = 768,
max_batch_size: int = 16,
# ONNX export parameters
onnx_opset: int = 17,
onnx_opset: int = 18,
onnx_dir: str = "onnx",
# TensorRT engine build parameters
engine_dir: str = "engine",
build_preview_features: bool = True,
force_engine_rebuild: bool = False,
timing_cache: str = "timing_cache",
):
super().__init__(
vae,
text_encoder,
tokenizer,
unet,
scheduler,
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
requires_safety_checker=requires_safety_checker,
)
self.vae.forward = self.vae.decode
self.stages = stages
self.image_height, self.image_width = image_height, image_width
self.inpaint = False
@@ -664,7 +711,6 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
self.timing_cache = timing_cache
self.build_static_batch = False
self.build_dynamic_shape = False
self.build_preview_features = build_preview_features
self.max_batch_size = max_batch_size
# TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation.
@@ -675,6 +721,11 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
self.models = {} # loaded in __loadModels()
self.engine = {} # loaded in build_engines()
self.vae.forward = self.vae.decode
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def __loadModels(self):
# Load pipeline models
self.embedding_dim = self.text_encoder.config.hidden_size
@@ -691,6 +742,75 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
if "vae" in self.stages:
self.models["vae"] = make_VAE(self.vae, **models_args)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(
self,
batch_size: int,
num_channels_latents: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
generator: Union[torch.Generator, List[torch.Generator]],
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""
Prepare the latent vectors for diffusion.
Args:
batch_size (int): The number of samples in the batch.
num_channels_latents (int): The number of channels in the latent vectors.
height (int): The height of the latent vectors.
width (int): The width of the latent vectors.
dtype (torch.dtype): The data type of the latent vectors.
device (torch.device): The device to place the latent vectors on.
generator (Union[torch.Generator, List[torch.Generator]]): The generator(s) to use for random number generation.
latents (Optional[torch.Tensor]): The pre-existing latent vectors. If None, new latent vectors will be generated.
Returns:
torch.Tensor: The prepared latent vectors.
"""
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(
self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
r"""
Runs the safety checker on the given image.
Args:
image (Union[torch.Tensor, PIL.Image.Image]): The input image to be checked.
device (torch.device): The device to run the safety checker on.
dtype (torch.dtype): The data type of the input image.
Returns:
(image, has_nsfw_concept) Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: A tuple containing the processed image and
a boolean indicating whether the image has a NSFW (Not Safe for Work) concept.
"""
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept
@classmethod
@validate_hf_hub_args
def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
@@ -738,7 +858,6 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
force_engine_rebuild=self.force_engine_rebuild,
static_batch=self.build_static_batch,
static_shape=not self.build_dynamic_shape,
enable_preview=self.build_preview_features,
timing_cache=self.timing_cache,
)
@@ -769,9 +888,8 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
.to(self.torch_device)
)
text_input_ids_inp = device_view(text_input_ids)
# NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[
text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids}, self.stream)[
"text_embeddings"
].clone()
@@ -787,8 +905,7 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
.input_ids.type(torch.int32)
.to(self.torch_device)
)
uncond_input_ids_inp = device_view(uncond_input_ids)
uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[
uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids}, self.stream)[
"text_embeddings"
]
@@ -812,18 +929,15 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
# Predict the noise residual
timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep
sample_inp = device_view(latent_model_input)
timestep_inp = device_view(timestep_float)
embeddings_inp = device_view(text_embeddings)
noise_pred = runEngine(
self.engine["unet"],
{"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp},
{"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings},
self.stream,
)["latent"]
# Perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample
@@ -831,12 +945,12 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
return latents
def __decode_latent(self, latents):
images = runEngine(self.engine["vae"], {"latent": device_view(latents)}, self.stream)["images"]
images = runEngine(self.engine["vae"], {"latent": latents}, self.stream)["images"]
images = (images / 2 + 0.5).clamp(0, 1)
return images.cpu().permute(0, 2, 3, 1).float().numpy()
def __loadResources(self, image_height, image_width, batch_size):
self.stream = cuda.Stream()
self.stream = cudart.cudaStreamCreate()[1]
# Allocate buffers for TensorRT engine bindings
for model_name, obj in self.models.items():