mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Fix 3-way merging with the checkpoint_merger community pipeline (#2355)
correctly locate 3rd file; also correct misleading docs
This commit is contained in:
@@ -80,8 +80,8 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
|
||||
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
||||
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_diff" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_diff" is supported.
|
||||
|
||||
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
@@ -206,14 +206,18 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
)
|
||||
)
|
||||
checkpoint_path_1 = files[0] if len(files) > 0 else None
|
||||
if checkpoint_path_2 is not None and os.path.exists(checkpoint_path_2):
|
||||
files = list(
|
||||
(
|
||||
*glob.glob(os.path.join(checkpoint_path_2, "*.safetensors")),
|
||||
*glob.glob(os.path.join(checkpoint_path_2, "*.bin")),
|
||||
if len(cached_folders) < 3:
|
||||
checkpoint_path_2 = None
|
||||
else:
|
||||
checkpoint_path_2 = os.path.join(cached_folders[2], attr)
|
||||
if os.path.exists(checkpoint_path_2):
|
||||
files = list(
|
||||
(
|
||||
*glob.glob(os.path.join(checkpoint_path_2, "*.safetensors")),
|
||||
*glob.glob(os.path.join(checkpoint_path_2, "*.bin")),
|
||||
)
|
||||
)
|
||||
)
|
||||
checkpoint_path_2 = files[0] if len(files) > 0 else None
|
||||
checkpoint_path_2 = files[0] if len(files) > 0 else None
|
||||
# For an attr if both checkpoint_path_1 and 2 are None, ignore.
|
||||
# If atleast one is present, deal with it according to interp method, of course only if the state_dict keys match.
|
||||
if checkpoint_path_1 is None and checkpoint_path_2 is None:
|
||||
|
||||
Reference in New Issue
Block a user