mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -137,11 +137,12 @@ class PeftLoraLoaderMixinTests:
|
||||
# Get or create the cache on the class (not instance)
|
||||
if not hasattr(type(self), "cached_base_pipe_outs"):
|
||||
setattr(type(self), "cached_base_pipe_outs", {})
|
||||
|
||||
|
||||
cached_base_pipe_outs = type(self).cached_base_pipe_outs
|
||||
|
||||
|
||||
all_scheduler_names = [scheduler_cls.__name__ for scheduler_cls in self.scheduler_classes]
|
||||
if cached_base_pipe_outs and all(k in cached_base_pipe_outs for k in all_scheduler_names):
|
||||
__import__("ipdb").set_trace()
|
||||
return
|
||||
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
@@ -158,12 +159,15 @@ class PeftLoraLoaderMixinTests:
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
cached_base_pipe_outs[scheduler_cls.__name__] = output_no_lora
|
||||
|
||||
|
||||
# Update the class attribute
|
||||
setattr(type(self), "cached_base_pipe_outs", cached_base_pipe_outs)
|
||||
|
||||
def get_base_pipeline_output(self, scheduler_cls):
|
||||
self._cache_base_pipeline_output()
|
||||
"""
|
||||
Returns the cached base pipeline output for the given scheduler.
|
||||
Cache is populated during setUp, so this just retrieves the value.
|
||||
"""
|
||||
return type(self).cached_base_pipe_outs[scheduler_cls.__name__]
|
||||
|
||||
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
|
||||
|
||||
Reference in New Issue
Block a user