mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Update checkpoint_merger pipeline to pass the "variant" argument (#6670)
* make checkpoint_merger pipeline pass the "variant" argument to from_pretrained() * make style --------- Co-authored-by: Lincoln Stein <lstein@gmail.com> Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
@@ -81,6 +81,8 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
|
||||
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
variant - which variant of a pretrained model to load, e.g. "fp16" (None)
|
||||
|
||||
"""
|
||||
# Default kwargs from DiffusionPipeline
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
@@ -89,6 +91,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
token = kwargs.pop("token", None)
|
||||
variant = kwargs.pop("variant", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
@@ -173,7 +176,10 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
# Step 3:-
|
||||
# Load the first checkpoint as a diffusion pipeline and modify its module state_dict in place
|
||||
final_pipe = DiffusionPipeline.from_pretrained(
|
||||
cached_folders[0], torch_dtype=torch_dtype, device_map=device_map
|
||||
cached_folders[0],
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=device_map,
|
||||
variant=variant,
|
||||
)
|
||||
final_pipe.to(self.device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user