mirror of https://github.com/huggingface/diffusers.git

commit 1e0856616a (parent fa926e78f5)
Author: DN6
Date:   2025-09-10 18:12:00 +05:30


@@ -129,20 +129,21 @@ class PeftLoraLoaderMixinTests:
     text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
     denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
     cached_base_pipe_outs = {}
 
     def setUp(self):
-        self.get_base_pipe_outs()
+        self._cache_base_pipeline_output()
         super().setUp()
 
-    def get_base_pipe_outs(self):
-        cached_base_pipe_outs = getattr(type(self), "cached_base_pipe_outs", {})
+    def _cache_base_pipeline_output(self):
+        # Get or create the cache on the class (not instance)
+        if not hasattr(type(self), "cached_base_pipe_outs"):
+            setattr(type(self), "cached_base_pipe_outs", {})
+        cached_base_pipe_outs = type(self).cached_base_pipe_outs
         all_scheduler_names = [scheduler_cls.__name__ for scheduler_cls in self.scheduler_classes]
 
         # Check if all required schedulers are already cached
         if cached_base_pipe_outs and all(k in cached_base_pipe_outs for k in all_scheduler_names):
             return
 
-        cached_base_pipe_outs = cached_base_pipe_outs or {}
         for scheduler_cls in self.scheduler_classes:
             if scheduler_cls.__name__ in cached_base_pipe_outs:
                 continue
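The change above keys the cache off the class object via type(self), so the expensive base-pipeline outputs are computed once per test class rather than once per test method (unittest constructs a fresh instance for every test). Below is a minimal, self-contained sketch of that caching pattern; the names CachedFixtureTest, cached_outputs, and _populate_cache are illustrative, not from the diffusers codebase:

import unittest


class CachedFixtureTest(unittest.TestCase):
    # Class-level cache: attributes on the class persist across the fresh
    # instance unittest creates for each test method.
    cached_outputs = {}

    def _populate_cache(self, key):
        # Read and write through type(self), mirroring the pattern in the hunk above.
        cache = type(self).cached_outputs
        if key not in cache:
            cache[key] = f"expensive result for {key}"  # stand-in for a full pipeline run
        return cache[key]

    def test_first_access_computes(self):
        self.assertEqual(self._populate_cache("a"), "expensive result for a")

    def test_second_access_reuses(self):
        self._populate_cache("a")
        # The cached value survives from whichever test touched it first.
        self.assertIn("a", type(self).cached_outputs)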
@@ -156,21 +157,14 @@ class PeftLoraLoaderMixinTests:
             # explicitly.
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            cached_base_pipe_outs.update({scheduler_cls.__name__: output_no_lora})
-        setattr(type(self), "cached_base_pipe_outs", cached_base_pipe_outs)
-
-    def get_base_pipeline_output(self, scheduler_cls):
-        """
-        Returns the cached base pipeline output for the given scheduler.
-        Properly handles accessing the class-level cache.
-        Ensures cache is populated if it hasn't been already.
-        """
-        # Ensure cache is populated
-        self.get_base_pipe_outs()
-        cached_base_pipe_outs = getattr(type(self), "cached_base_pipe_outs", {})
-        return cached_base_pipe_outs[scheduler_cls.__name__]
+            cached_base_pipe_outs[scheduler_cls.__name__] = output_no_lora
+            # Update the class attribute
+            setattr(type(self), "cached_base_pipe_outs", cached_base_pipe_outs)
+
+    def get_base_pipeline_output(self, scheduler_cls):
+        self._cache_base_pipeline_output()
+        return type(self).cached_base_pipe_outs[scheduler_cls.__name__]
 
     def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         if self.unet_kwargs and self.transformer_kwargs:
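A note on the design choice: the standard-library way to build a fixture once per class is unittest's setUpClass hook, while this diff populates the cache lazily from both setUp and get_base_pipeline_output. One likely reason is visible in the loop above: each subclass fills in only the schedulers it actually uses, and a cache hit skips the pipeline run entirely. For contrast, here is a minimal sketch of the eager setUpClass alternative; EagerFixtureTest and base_output are illustrative names, not from the repository:

import unittest


class EagerFixtureTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Runs exactly once per class, before any test method.
        cls.base_output = "expensive result"  # stand-in for a base pipeline run

    def test_sees_shared_fixture(self):
        # Instance attribute lookup falls through to the class attribute.
        self.assertEqual(self.base_output, "expensive result")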