1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Update checkpoint_merger pipeline to pass the "variant" argument (#6670)

* make checkpoint_merger pipeline pass the "variant" argument to from_pretrained()

* make style

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
Lincoln Stein
2024-02-21 20:45:50 -05:00
committed by GitHub
parent 5a54dc9e95
commit d5f444de4b

View File

@@ -81,6 +81,8 @@ class CheckpointMergerPipeline(DiffusionPipeline):
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
variant - which variant of a pretrained model to load, e.g. "fp16" (None)
"""
# Default kwargs from DiffusionPipeline
cache_dir = kwargs.pop("cache_dir", None)
@@ -89,6 +91,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
token = kwargs.pop("token", None)
variant = kwargs.pop("variant", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
device_map = kwargs.pop("device_map", None)
@@ -173,7 +176,10 @@ class CheckpointMergerPipeline(DiffusionPipeline):
# Step 3:-
# Load the first checkpoint as a diffusion pipeline and modify its module state_dict in place
final_pipe = DiffusionPipeline.from_pretrained(
cached_folders[0], torch_dtype=torch_dtype, device_map=device_map
cached_folders[0],
torch_dtype=torch_dtype,
device_map=device_map,
variant=variant,
)
final_pipe.to(self.device)