From bf92e746c05b7da59a3ddf233dbc35b2587b3d2f Mon Sep 17 00:00:00 2001
From: gujing <925973396@qq.com>
Date: Mon, 4 Dec 2023 12:36:23 +0800
Subject: [PATCH] fix StableDiffusionTensorRT super args error (#6009)

---
 .../community/stable_diffusion_tensorrt_img2img.py | 13 +++++++++++--
 .../community/stable_diffusion_tensorrt_inpaint.py | 13 +++++++++++--
 .../community/stable_diffusion_tensorrt_txt2img.py | 13 +++++++++++--
 3 files changed, 33 insertions(+), 6 deletions(-)

diff --git a/examples/community/stable_diffusion_tensorrt_img2img.py b/examples/community/stable_diffusion_tensorrt_img2img.py
index 041cf3a12d..507177791f 100755
--- a/examples/community/stable_diffusion_tensorrt_img2img.py
+++ b/examples/community/stable_diffusion_tensorrt_img2img.py
@@ -41,7 +41,7 @@ from polygraphy.backend.trt import (
     save_engine,
 )
 from polygraphy.backend.trt import util as trt_util
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import (
@@ -709,6 +709,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
         scheduler: DDIMScheduler,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
+        image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
         stages=["clip", "unet", "vae", "vae_encoder"],
         image_height: int = 512,
@@ -724,7 +725,15 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
         timing_cache: str = "timing_cache",
     ):
         super().__init__(
-            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+            vae,
+            text_encoder,
+            tokenizer,
+            unet,
+            scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
+            requires_safety_checker=requires_safety_checker,
         )
 
         self.vae.forward = self.vae.decode
diff --git a/examples/community/stable_diffusion_tensorrt_inpaint.py b/examples/community/stable_diffusion_tensorrt_inpaint.py
index 71fa1b0a5f..b4e16c7615 100755
--- a/examples/community/stable_diffusion_tensorrt_inpaint.py
+++ b/examples/community/stable_diffusion_tensorrt_inpaint.py
@@ -41,7 +41,7 @@ from polygraphy.backend.trt import (
     save_engine,
 )
 from polygraphy.backend.trt import util as trt_util
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import (
@@ -710,6 +710,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
         scheduler: DDIMScheduler,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
+        image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
         stages=["clip", "unet", "vae", "vae_encoder"],
         image_height: int = 512,
@@ -725,7 +726,15 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
         timing_cache: str = "timing_cache",
     ):
         super().__init__(
-            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+            vae,
+            text_encoder,
+            tokenizer,
+            unet,
+            scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
+            requires_safety_checker=requires_safety_checker,
         )
 
         self.vae.forward = self.vae.decode
diff --git a/examples/community/stable_diffusion_tensorrt_txt2img.py b/examples/community/stable_diffusion_tensorrt_txt2img.py
index b51f3176b9..c382614633 100755
--- a/examples/community/stable_diffusion_tensorrt_txt2img.py
+++ b/examples/community/stable_diffusion_tensorrt_txt2img.py
@@ -40,7 +40,7 @@ from polygraphy.backend.trt import (
     save_engine,
 )
 from polygraphy.backend.trt import util as trt_util
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import (
@@ -624,6 +624,7 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
         scheduler: DDIMScheduler,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
+        image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
         stages=["clip", "unet", "vae"],
         image_height: int = 768,
@@ -639,7 +640,15 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
         timing_cache: str = "timing_cache",
     ):
         super().__init__(
-            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+            vae,
+            text_encoder,
+            tokenizer,
+            unet,
+            scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
+            requires_safety_checker=requires_safety_checker,
         )
 
         self.vae.forward = self.vae.decode
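The underlying failure is a positional-argument mismatch: as the hunks above show, the parent pipelines gained an image_encoder parameter between feature_extractor and requires_safety_checker, so the old eight-argument positional super().__init__ call would bind the trailing requires_safety_checker bool to the new image_encoder slot. Switching the trailing arguments to keywords, as the patch does, pins each value to its intended parameter and makes the call robust to future signature insertions. A minimal, self-contained sketch of the failure mode and the fix (hypothetical Parent/Child names; only the argument-binding behavior mirrors the diffusers classes):

# Sketch of the bug fixed by this patch (hypothetical names, not the actual
# diffusers classes). A parameter inserted mid-signature in the parent
# silently shifts any subclass that still passes arguments by position.

class Parent:
    # Before: __init__(self, feature_extractor, requires_safety_checker=True)
    # After a release inserted image_encoder in the middle:
    def __init__(self, feature_extractor, image_encoder=None, requires_safety_checker=True):
        self.feature_extractor = feature_extractor
        self.image_encoder = image_encoder
        self.requires_safety_checker = requires_safety_checker


class BrokenChild(Parent):
    def __init__(self, feature_extractor, requires_safety_checker=True):
        # Positional call written against the old signature:
        # the bool lands in the image_encoder slot.
        super().__init__(feature_extractor, requires_safety_checker)


class FixedChild(Parent):
    def __init__(self, feature_extractor, image_encoder=None, requires_safety_checker=True):
        # Keyword arguments pin each value to the intended parameter,
        # which is what the patch does in all three pipelines.
        super().__init__(
            feature_extractor,
            image_encoder=image_encoder,
            requires_safety_checker=requires_safety_checker,
        )


broken = BrokenChild("extractor", requires_safety_checker=False)
assert broken.image_encoder is False           # bool mis-bound as the encoder
assert broken.requires_safety_checker is True  # silently kept its default

fixed = FixedChild("extractor", requires_safety_checker=False)
assert fixed.image_encoder is None
assert fixed.requires_safety_checker is False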