diff --git a/examples/community/checkpoint_merger.py b/examples/community/checkpoint_merger.py index 576c2cdb4d..24f187b41c 100644 --- a/examples/community/checkpoint_merger.py +++ b/examples/community/checkpoint_merger.py @@ -80,8 +80,8 @@ class CheckpointMergerPipeline(DiffusionPipeline): alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 - interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None. - Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. + interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_diff" and None. + Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_diff" is supported. force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False. @@ -206,14 +206,18 @@ class CheckpointMergerPipeline(DiffusionPipeline): ) ) checkpoint_path_1 = files[0] if len(files) > 0 else None - if checkpoint_path_2 is not None and os.path.exists(checkpoint_path_2): - files = list( - ( - *glob.glob(os.path.join(checkpoint_path_2, "*.safetensors")), - *glob.glob(os.path.join(checkpoint_path_2, "*.bin")), + if len(cached_folders) < 3: + checkpoint_path_2 = None + else: + checkpoint_path_2 = os.path.join(cached_folders[2], attr) + if os.path.exists(checkpoint_path_2): + files = list( + ( + *glob.glob(os.path.join(checkpoint_path_2, "*.safetensors")), + *glob.glob(os.path.join(checkpoint_path_2, "*.bin")), + ) ) - ) - checkpoint_path_2 = files[0] if len(files) > 0 else None + checkpoint_path_2 = files[0] if len(files) > 0 else None # For an attr if both checkpoint_path_1 and 2 are None, ignore. # If atleast one is present, deal with it according to interp method, of course only if the state_dict keys match. if checkpoint_path_1 is None and checkpoint_path_2 is None: