1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Kolors additional pipelines, community contrib (#11372)

* Kolors additional pipelines, community contrib

---------

Co-authored-by: Teriks <Teriks@users.noreply.github.com>
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
This commit is contained in:
Teriks
2025-04-23 16:07:27 -05:00
committed by GitHub
parent a4f9c3cbc3
commit b4be42282d
8 changed files with 6517 additions and 6 deletions

View File

@@ -433,7 +433,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
if not is_sparse:
# down_weight is copied to each split
ait_sd.update({k: down_weight for k in ait_down_keys})
ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
# up_weight is split to each split
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
@@ -923,7 +923,7 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
# down_weight is copied to each split
ait_sd.update({k: down_weight for k in ait_down_keys})
ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
# up_weight is split to each split
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416

View File

@@ -469,7 +469,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
class_obj = import_flax_or_no_model(pipeline_module, class_name)
importable_classes = ALL_IMPORTABLE_CLASSES
class_candidates = {c: class_obj for c in importable_classes.keys()}
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
else:
# else we just import it from the library.
library = importlib.import_module(library_name)

View File

@@ -341,13 +341,13 @@ def get_class_obj_and_candidates(
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
class_candidates = {c: class_obj for c in importable_classes.keys()}
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
# load custom component
class_obj = get_class_from_dynamic_module(
component_folder, module_file=library_name + ".py", class_name=class_name
)
class_candidates = {c: class_obj for c in importable_classes.keys()}
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
else:
# else we just import it from the library.
library = importlib.import_module(library_name)