[ONNX] Improve ONNXPipeline scheduler compatibility, fix safety_checker (#1173)
* [ONNX] Improve ONNX scheduler compatibility, fix safety_checker
* typo
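The export entry point itself is unchanged by this commit. A minimal sketch of driving the conversion by calling convert_models() directly, assuming the script is importable from a checkout of the repository; the model id and output directory below are illustrative assumptions, not part of the commit:

# Minimal sketch: call convert_models() from
# scripts/convert_stable_diffusion_checkpoint_to_onnx.py in a diffusers checkout.
from convert_stable_diffusion_checkpoint_to_onnx import convert_models

convert_models(
    model_path="runwayml/stable-diffusion-v1-5",  # hub id or local checkpoint path (assumption)
    output_path="./stable-diffusion-onnx",        # target directory for the ONNX graphs (assumption)
    opset=14,                                     # ONNX opset to export against
    fp16=False,                                   # half-precision export requires CUDA
)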
@@ -81,6 +81,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
     output_path = Path(output_path)
 
     # TEXT ENCODER
+    num_tokens = pipeline.text_encoder.config.max_position_embeddings
+    text_hidden_size = pipeline.text_encoder.config.hidden_size
     text_input = pipeline.tokenizer(
         "A sample prompt",
         padding="max_length",
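The two added lines replace shapes that were previously hard-coded further down (77 tokens, 768-dim embeddings, i.e. CLIP ViT-L/14) with values read from the text-encoder config, so checkpoints with a different text encoder export with the right dummy shapes. A quick check of the two fields, assuming a v1-style checkpoint (the hub id is illustrative):

# Prints the two config fields the export now relies on; the expected
# values (77, 768) assume the text encoder is CLIP ViT-L/14.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
print(pipe.text_encoder.config.max_position_embeddings)  # 77
print(pipe.text_encoder.config.hidden_size)              # 768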
@@ -103,13 +105,15 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
     del pipeline.text_encoder
 
     # UNET
+    unet_in_channels = pipeline.unet.config.in_channels
+    unet_sample_size = pipeline.unet.config.sample_size
     unet_path = output_path / "unet" / "model.onnx"
     onnx_export(
         pipeline.unet,
         model_args=(
-            torch.randn(2, pipeline.unet.in_channels, 64, 64).to(device=device, dtype=dtype),
-            torch.LongTensor([0, 1]).to(device=device),
-            torch.randn(2, 77, 768).to(device=device, dtype=dtype),
+            torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype),
+            torch.randn(2).to(device=device, dtype=dtype),
+            torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
             False,
         ),
         output_path=unet_path,
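Besides deriving the latent shape from the UNet config, the dummy timestep changes from torch.LongTensor([0, 1]) to torch.randn(2), so the graph is traced with a float timestep input; schedulers that produce non-integer timesteps can then drive the exported UNet, which is the scheduler-compatibility half of this commit. A sketch of calling the exported UNet through onnxruntime, assuming the v1 default shapes, the output directory from the earlier sketch, and input names matching the script's ordered_input_names:

# Sketch: run the exported UNet with a float timestep via onnxruntime.
# Shapes assume the v1 defaults (4 latent channels, 64x64 latents,
# 77x768 text embeddings); path and input names are assumptions.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(
    "./stable-diffusion-onnx/unet/model.onnx", providers=["CPUExecutionProvider"]
)
noise_pred = sess.run(
    None,
    {
        "sample": np.zeros((2, 4, 64, 64), dtype=np.float32),
        "timestep": np.array([999.0, 999.0], dtype=np.float32),  # float, no longer int64
        "encoder_hidden_states": np.zeros((2, 77, 768), dtype=np.float32),
    },
)[0]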
@@ -142,11 +146,16 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
 
     # VAE ENCODER
     vae_encoder = pipeline.vae
+    vae_in_channels = vae_encoder.config.in_channels
+    vae_sample_size = vae_encoder.config.sample_size
     # need to get the raw tensor output (sample) from the encoder
     vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample()
     onnx_export(
         vae_encoder,
-        model_args=(torch.randn(1, 3, 512, 512).to(device=device, dtype=dtype), False),
+        model_args=(
+            torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(device=device, dtype=dtype),
+            False,
+        ),
         output_path=output_path / "vae_encoder" / "model.onnx",
         ordered_input_names=["sample", "return_dict"],
         output_names=["latent_sample"],
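For reference, the monkey-patched forward is what gets traced: it routes the call through the encoder only and unwraps the output object so the graph emits a raw latent tensor. Spelled out as a plain function (a sketch of what the lambda above does):

# Equivalent of the lambda assigned to vae_encoder.forward above:
def vae_encoder_forward(sample, return_dict):
    # encode() returns an output object whose first field is the latent
    # distribution; .sample() draws a latent tensor from it, which is
    # the raw tensor the ONNX graph should emit.
    latent_dist = vae_encoder.encode(sample, return_dict)[0]
    return latent_dist.sample()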
@@ -158,11 +167,16 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
 
     # VAE DECODER
     vae_decoder = pipeline.vae
+    vae_latent_channels = vae_decoder.config.latent_channels
+    vae_out_channels = vae_decoder.config.out_channels
     # forward only through the decoder part
     vae_decoder.forward = vae_encoder.decode
     onnx_export(
         vae_decoder,
-        model_args=(torch.randn(1, 4, 64, 64).to(device=device, dtype=dtype), False),
+        model_args=(
+            torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype),
+            False,
+        ),
         output_path=output_path / "vae_decoder" / "model.onnx",
         ordered_input_names=["latent_sample", "return_dict"],
         output_names=["sample"],
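The decoder's dummy input reuses unet_sample_size for its spatial dims because the latents it consumes live at the UNet's resolution, one 8x downsampling below the vae_sample_size pixels the encoder sees (for v1: 64 vs. 512). A round-trip check of that relationship, using an illustrative model id:

# Shape check for the two VAE exports above (v1 defaults assumed):
# pixels at vae_sample_size (512), latents at unet_sample_size (64).
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
latents = vae.encode(torch.randn(1, 3, 512, 512)).latent_dist.sample()
print(latents.shape)  # torch.Size([1, 4, 64, 64])
image = vae.decode(latents).sample
print(image.shape)    # torch.Size([1, 3, 512, 512])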
@@ -174,24 +188,35 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
     del pipeline.vae
 
     # SAFETY CHECKER
-    safety_checker = pipeline.safety_checker
-    safety_checker.forward = safety_checker.forward_onnx
-    onnx_export(
-        pipeline.safety_checker,
-        model_args=(
-            torch.randn(1, 3, 224, 224).to(device=device, dtype=dtype),
-            torch.randn(1, 512, 512, 3).to(device=device, dtype=dtype),
-        ),
-        output_path=output_path / "safety_checker" / "model.onnx",
-        ordered_input_names=["clip_input", "images"],
-        output_names=["out_images", "has_nsfw_concepts"],
-        dynamic_axes={
-            "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
-            "images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
-        },
-        opset=opset,
-    )
-    del pipeline.safety_checker
+    if pipeline.safety_checker is not None:
+        safety_checker = pipeline.safety_checker
+        clip_num_channels = safety_checker.config.vision_config.num_channels
+        clip_image_size = safety_checker.config.vision_config.image_size
+        safety_checker.forward = safety_checker.forward_onnx
+        onnx_export(
+            pipeline.safety_checker,
+            model_args=(
+                torch.randn(
+                    1,
+                    clip_num_channels,
+                    clip_image_size,
+                    clip_image_size,
+                ).to(device=device, dtype=dtype),
+                torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(device=device, dtype=dtype),
+            ),
+            output_path=output_path / "safety_checker" / "model.onnx",
+            ordered_input_names=["clip_input", "images"],
+            output_names=["out_images", "has_nsfw_concepts"],
+            dynamic_axes={
+                "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+                "images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
+            },
+            opset=opset,
+        )
+        del pipeline.safety_checker
+        safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker")
+    else:
+        safety_checker = None
 
     onnx_pipeline = OnnxStableDiffusionPipeline(
         vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
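This is the safety_checker fix from the title: the old block called safety_checker.forward_onnx unconditionally and crashed on checkpoints distributed without a safety checker, while the guarded version exports the checker only when present and otherwise carries safety_checker=None into the assembled pipeline. It also sizes the CLIP input from the checker's vision config instead of hard-coding 3x224x224. A sketch of the case the guard covers, with an illustrative model id:

# Loading a pipeline with the checker disabled (diffusers allows this,
# with a warning); conversion now succeeds instead of raising on None.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    safety_checker=None,  # illustrative: a distribution without a checker
)
assert pipe.safety_checker is None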
@@ -200,7 +225,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
         tokenizer=pipeline.tokenizer,
         unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
         scheduler=pipeline.scheduler,
-        safety_checker=OnnxRuntimeModel.from_pretrained(output_path / "safety_checker"),
+        safety_checker=safety_checker,
         feature_extractor=pipeline.feature_extractor,
     )
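End to end, the converted folder loads back through the same ONNX pipeline class assembled above. A usage sketch, assuming the output directory from the earlier conversion sketch:

# Sketch: reload the exported pipeline and generate an image on CPU.
from diffusers import OnnxStableDiffusionPipeline

pipe = OnnxStableDiffusionPipeline.from_pretrained(
    "./stable-diffusion-onnx", provider="CPUExecutionProvider"
)
image = pipe("A sample prompt", num_inference_steps=25).images[0]
image.save("sample.png")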