1
0
Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
DN6
2025-09-11 08:07:04 +05:30
parent 1e0856616a
commit 40f12d2aea

View File

@@ -137,11 +137,12 @@ class PeftLoraLoaderMixinTests:
# Get or create the cache on the class (not instance)
if not hasattr(type(self), "cached_base_pipe_outs"):
setattr(type(self), "cached_base_pipe_outs", {})
cached_base_pipe_outs = type(self).cached_base_pipe_outs
all_scheduler_names = [scheduler_cls.__name__ for scheduler_cls in self.scheduler_classes]
if cached_base_pipe_outs and all(k in cached_base_pipe_outs for k in all_scheduler_names):
__import__("ipdb").set_trace()
return
for scheduler_cls in self.scheduler_classes:
@@ -158,12 +159,15 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
cached_base_pipe_outs[scheduler_cls.__name__] = output_no_lora
# Update the class attribute
setattr(type(self), "cached_base_pipe_outs", cached_base_pipe_outs)
def get_base_pipeline_output(self, scheduler_cls):
    """
    Return the cached base (no-LoRA) pipeline output for ``scheduler_cls``.

    Ensures the class-level cache is populated by calling
    ``_cache_base_pipeline_output`` first, then looks up the entry keyed
    by the scheduler class name.

    Args:
        scheduler_cls: Scheduler class whose cached baseline output to fetch.

    Returns:
        The pipeline output produced without any LoRA applied.

    Raises:
        KeyError: If no cached output exists for ``scheduler_cls`` even
            after attempting to populate the cache.
    """
    # NOTE: the original placed the triple-quoted string *after* this call,
    # making it a no-op expression rather than a docstring; it also claimed
    # the cache was filled during setUp, which contradicts the explicit
    # population call below.
    self._cache_base_pipeline_output()
    return type(self).cached_base_pipe_outs[scheduler_cls.__name__]
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):