diff --git a/examples/community/checkpoint_merger.py b/examples/community/checkpoint_merger.py index 218ac87fe5..b9cb1463e3 100644 --- a/examples/community/checkpoint_merger.py +++ b/examples/community/checkpoint_merger.py @@ -81,6 +81,8 @@ class CheckpointMergerPipeline(DiffusionPipeline): force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False. + variant - which variant of a pretrained model to load, e.g. "fp16" (None) + """ # Default kwargs from DiffusionPipeline cache_dir = kwargs.pop("cache_dir", None) @@ -89,6 +91,7 @@ class CheckpointMergerPipeline(DiffusionPipeline): proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) token = kwargs.pop("token", None) + variant = kwargs.pop("variant", None) revision = kwargs.pop("revision", None) torch_dtype = kwargs.pop("torch_dtype", None) device_map = kwargs.pop("device_map", None) @@ -173,7 +176,10 @@ class CheckpointMergerPipeline(DiffusionPipeline): # Step 3:- # Load the first checkpoint as a diffusion pipeline and modify its module state_dict in place final_pipe = DiffusionPipeline.from_pretrained( - cached_folders[0], torch_dtype=torch_dtype, device_map=device_map + cached_folders[0], + torch_dtype=torch_dtype, + device_map=device_map, + variant=variant, ) final_pipe.to(self.device)