mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -129,20 +129,21 @@ class PeftLoraLoaderMixinTests:
|
||||
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
|
||||
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
|
||||
|
||||
cached_base_pipe_outs = {}
|
||||
|
||||
def setUp(self):
|
||||
self.get_base_pipe_outs()
|
||||
self._cache_base_pipeline_output()
|
||||
super().setUp()
|
||||
|
||||
def get_base_pipe_outs(self):
|
||||
cached_base_pipe_outs = getattr(type(self), "cached_base_pipe_outs", {})
|
||||
def _cache_base_pipeline_output(self):
|
||||
# Get or create the cache on the class (not instance)
|
||||
if not hasattr(type(self), "cached_base_pipe_outs"):
|
||||
setattr(type(self), "cached_base_pipe_outs", {})
|
||||
|
||||
cached_base_pipe_outs = type(self).cached_base_pipe_outs
|
||||
|
||||
all_scheduler_names = [scheduler_cls.__name__ for scheduler_cls in self.scheduler_classes]
|
||||
# Check if all required schedulers are already cached
|
||||
if cached_base_pipe_outs and all(k in cached_base_pipe_outs for k in all_scheduler_names):
|
||||
return
|
||||
|
||||
cached_base_pipe_outs = cached_base_pipe_outs or {}
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
if scheduler_cls.__name__ in cached_base_pipe_outs:
|
||||
continue
|
||||
@@ -156,21 +157,14 @@ class PeftLoraLoaderMixinTests:
|
||||
# explicitly.
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
cached_base_pipe_outs.update({scheduler_cls.__name__: output_no_lora})
|
||||
|
||||
setattr(type(self), "cached_base_pipe_outs", cached_base_pipe_outs)
|
||||
|
||||
def get_base_pipeline_output(self, scheduler_cls):
|
||||
"""
|
||||
Returns the cached base pipeline output for the given scheduler.
|
||||
Properly handles accessing the class-level cache.
|
||||
Ensures cache is populated if it hasn't been already.
|
||||
"""
|
||||
# Ensure cache is populated
|
||||
self.get_base_pipe_outs()
|
||||
cached_base_pipe_outs[scheduler_cls.__name__] = output_no_lora
|
||||
|
||||
cached_base_pipe_outs = getattr(type(self), "cached_base_pipe_outs", {})
|
||||
return cached_base_pipe_outs[scheduler_cls.__name__]
|
||||
# Update the class attribute
|
||||
setattr(type(self), "cached_base_pipe_outs", cached_base_pipe_outs)
|
||||
|
||||
def get_base_pipeline_output(self, scheduler_cls):
|
||||
self._cache_base_pipeline_output()
|
||||
return type(self).cached_base_pipe_outs[scheduler_cls.__name__]
|
||||
|
||||
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
|
||||
if self.unet_kwargs and self.transformer_kwargs:
|
||||
|
||||
Reference in New Issue
Block a user