mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
⚡️ Speed up method AutoencoderKLWan.clear_cache by 886% (#11665)
* ⚡️ Speed up method `AutoencoderKLWan.clear_cache` by 886% **Key optimizations:** - Compute the number of `WanCausalConv3d` modules in each model (`encoder`/`decoder`) **only once during initialization**, store in `self._cached_conv_counts`. This removes unnecessary repeated tree traversals at every `clear_cache` call, which was the main bottleneck (from profiling). - The internal helper `_count_conv3d_fast` is optimized via a generator expression with `sum` for efficiency. All comments from the original code are preserved, except for updated or removed local docstrings/comments relevant to changed lines. **Function signatures and outputs remain unchanged.** * Apply style fixes * Apply suggestions from code review Co-authored-by: Aryan <contact.aryanvs@gmail.com> * Apply style fixes --------- Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Aryan <aryan@huggingface.co> Co-authored-by: Aryan <contact.aryanvs@gmail.com> Co-authored-by: Aseem Saxena <aseem.bits@gmail.com>
This commit is contained in:
@@ -749,6 +749,16 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.tile_sample_stride_height = 192
         self.tile_sample_stride_width = 192
 
+        # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
+        self._cached_conv_counts = {
+            "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
+            if self.decoder is not None
+            else 0,
+            "encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules())
+            if self.encoder is not None
+            else 0,
+        }
+
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
@@ -801,18 +811,12 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.use_slicing = False
 
     def clear_cache(self):
-        def _count_conv3d(model):
-            count = 0
-            for m in model.modules():
-                if isinstance(m, WanCausalConv3d):
-                    count += 1
-            return count
-
-        self._conv_num = _count_conv3d(self.decoder)
+        # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
+        self._conv_num = self._cached_conv_counts["decoder"]
         self._conv_idx = [0]
         self._feat_map = [None] * self._conv_num
         # cache encode
-        self._enc_conv_num = _count_conv3d(self.encoder)
+        self._enc_conv_num = self._cached_conv_counts["encoder"]
         self._enc_conv_idx = [0]
         self._enc_feat_map = [None] * self._enc_conv_num
 
Reference in New Issue
Block a user