diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py
index 6ff45ac9e8..4fbb6089f3 100644
--- a/src/diffusers/models/transformers/transformer_glm_image.py
+++ b/src/diffusers/models/transformers/transformer_glm_image.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -109,12 +108,49 @@ class GlmImageAdaLayerNormZero(nn.Module):
         )
 
 
-class GlmImageAttenProcessorState(Enum):
-    ImageGen = "ImageGen"
-    ImageEditWriteKV = "ImageEditWriteKV"
-    ImageEditReadKV = "ImageEditReadKV"
-    ImageEditDontReadKV = "ImageEditNoReadKV"
+class GlmImageLayerKVCache:
+    """KV cache for a single transformer layer of the GlmImage model."""
+
+    def __init__(self):
+        self.k_cache = None
+        self.v_cache = None
+        self.mode: Optional[str] = None  # "write", "read", "skip"
+
+    def store(self, k: torch.Tensor, v: torch.Tensor):
+        if self.k_cache is None:
+            self.k_cache = k
+            self.v_cache = v
+        else:
+            self.k_cache = torch.cat([self.k_cache, k], dim=2)
+            self.v_cache = torch.cat([self.v_cache, v], dim=2)
+
+    def get(self):
+        return self.k_cache, self.v_cache
+
+    def clear(self):
+        self.k_cache = None
+        self.v_cache = None
+        self.mode = None
+
+
+class GlmImageKVCache:
+    """Container for all layers' KV caches."""
+
+    def __init__(self, num_layers: int):
+        self.num_layers = num_layers
+        self.caches = [GlmImageLayerKVCache() for _ in range(num_layers)]
+
+    def __getitem__(self, layer_idx: int) -> GlmImageLayerKVCache:
+        return self.caches[layer_idx]
+
+    def set_mode(self, mode: Optional[str]):
+        if mode is not None and mode not in ["write", "read", "skip"]:
+            raise ValueError(f"Invalid mode: {mode}, must be one of 'write', 'read', 'skip'")
+        for cache in self.caches:
+            cache.mode = mode
+
+    def clear(self):
+        for cache in self.caches:
+            cache.clear()
 
 
 class GlmImageAttnProcessor:
     """
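For orientation, a minimal standalone sketch of the semantics the two classes above are meant to have. It assumes the classes are importable from `transformer_glm_image`; the shapes `[batch, heads, seq_len, head_dim]` and sizes are illustrative, not taken from the model:

```python
import torch

from diffusers.models.transformers.transformer_glm_image import GlmImageKVCache

caches = GlmImageKVCache(num_layers=2)
k = torch.randn(1, 4, 16, 64)  # [batch, heads, seq_len, head_dim]
v = torch.randn(1, 4, 16, 64)

caches.set_mode("write")                 # broadcasts the mode to every layer
caches[0].store(k, v)                    # first store adopts k/v directly
caches[0].store(k, v)                    # later stores concatenate along dim=2

k_cache, _ = caches[0].get()
assert k_cache.shape[2] == 32            # 16 + 16 cached sequence positions
assert caches[1].get() == (None, None)   # each layer caches independently

caches.clear()                           # drops tensors, resets mode to None
```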
@@ -128,9 +164,6 @@ class GlmImageAttnProcessor:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("GlmImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
-        self.processor_state = GlmImageAttenProcessorState.ImageGen
-        self.k_cache = None
-        self.v_cache = None
 
     def __call__(
         self,
@@ -139,6 +172,7 @@ class GlmImageAttnProcessor:
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        kv_cache: Optional[GlmImageLayerKVCache] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         dtype = encoder_hidden_states.dtype
 
@@ -172,12 +206,15 @@ class GlmImageAttnProcessor:
             key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
         )
 
-        if self.processor_state == GlmImageAttenProcessorState.ImageEditWriteKV:
-            self.k_cache = key if self.k_cache is None else torch.cat([self.k_cache, key], dim=2)
-            self.v_cache = value if self.v_cache is None else torch.cat([self.v_cache, value], dim=2)
-        elif self.processor_state == GlmImageAttenProcessorState.ImageEditReadKV:
-            key = torch.cat([self.k_cache, key], dim=2) if self.k_cache is not None else key
-            value = torch.cat([self.v_cache, value], dim=2) if self.v_cache is not None else value
+        if kv_cache is not None:
+            if kv_cache.mode == "write":
+                kv_cache.store(key, value)
+            elif kv_cache.mode == "read":
+                k_cache, v_cache = kv_cache.get()
+                key = torch.cat([k_cache, key], dim=2) if k_cache is not None else key
+                value = torch.cat([v_cache, value], dim=2) if v_cache is not None else value
+            elif kv_cache.mode == "skip":
+                pass
 
         # 4. Attention
         if attention_mask is not None:
@@ -246,6 +283,7 @@ class GlmImageTransformerBlock(nn.Module):
         ] = None,
         attention_mask: Optional[Dict[str, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
+        kv_cache: Optional[GlmImageLayerKVCache] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # 1. Timestep conditioning
         (
@@ -270,6 +308,7 @@ class GlmImageTransformerBlock(nn.Module):
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
             attention_mask=attention_mask,
+            kv_cache=kv_cache,
             **attention_kwargs,
         )
         hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
@@ -464,6 +503,7 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
         attention_mask: Optional[torch.Tensor] = None,
+        kv_caches: Optional[GlmImageKVCache] = None,
         image_rotary_emb: Optional[
             Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
         ] = None,
@@ -491,7 +531,7 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
         temb = F.silu(temb)
 
         # 3. Transformer blocks
-        for block in self.transformer_blocks:
+        for idx, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
                     block,
@@ -501,6 +541,7 @@
                     image_rotary_emb,
                     attention_mask,
                     attention_kwargs,
+                    kv_caches[idx] if kv_caches is not None else None,
                 )
             else:
                 hidden_states, encoder_hidden_states = block(
@@ -510,6 +551,7 @@
                     image_rotary_emb,
                     attention_mask,
                     attention_kwargs,
+                    kv_cache=kv_caches[idx] if kv_caches is not None else None,
                 )
 
         # 4. Output norm & projection
@@ -523,12 +565,3 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
-
-    def set_attention_processors_state(self, state: GlmImageAttenProcessorState):
-        for block in self.transformer_blocks:
-            block.attn1.processor.processor_state = state
-
-    def clear_attention_processors_cache(self):
-        for block in self.transformer_blocks:
-            block.attn1.processor.k_cache = None
-            block.attn1.processor.v_cache = None
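The forward pass above hands block `idx` its own `kv_caches[idx]`, and the processor dispatches on `mode`. As a sanity aid, here is a toy restatement of that dispatch that could be unit-tested against the processor; it assumes the cache classes above and key/value shaped `[batch, heads, seq_len, head_dim]`:

```python
import torch

def apply_kv_cache(key, value, kv_cache):
    # Mirrors the branch added to GlmImageAttnProcessor.__call__ above.
    if kv_cache is not None:
        if kv_cache.mode == "write":
            kv_cache.store(key, value)  # remember the condition image's K/V
        elif kv_cache.mode == "read":
            k_cache, v_cache = kv_cache.get()
            if k_cache is not None:
                key = torch.cat([k_cache, key], dim=2)    # cached tokens first
                value = torch.cat([v_cache, value], dim=2)
        # mode "skip" (or None): attend only to the current tokens
    return key, value
```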
diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py
index 0de4e5db54..9c4e05d859 100644
--- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py
+++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py
@@ -27,7 +27,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import VaeImageProcessor
 from ...loaders import CogView4LoraLoaderMixin
 from ...models import AutoencoderKL, GlmImageTransformer2DModel
-from ...models.transformers.transformer_glm_image import GlmImageAttenProcessorState
+from ...models.transformers.transformer_glm_image import GlmImageKVCache
 from ...pipelines.pipeline_utils import DiffusionPipeline
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import is_torch_xla_available, logging, replace_example_docstring
@@ -719,8 +719,10 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
             latents=latents,
         )
 
+        kv_caches = GlmImageKVCache(num_layers=self.transformer.config.num_layers)
+
         if image is not None and condition_images_prior_token_id is not None:
-            self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditWriteKV)
+            kv_caches.set_mode("write")
             latents_mean = (
                 torch.tensor(self.vae.config.latents_mean)
                 .view(1, self.vae.config.latent_channels, 1, 1)
@@ -747,6 +749,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                     target_size=torch.tensor([condition_image.shape[-2:]], device=device),
                     crop_coords=torch.zeros((1, 2), device=device),
                     attention_kwargs=attention_kwargs,
+                    kv_caches=kv_caches,
                 )
 
         # 6. Prepare additional timestep conditions
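Condensed, the warm-up these hunks perform is one "write" forward per condition image before denoising begins. A sketch under the assumption of a loaded transformer; the argument names mirror the diff, while the other call details are illustrative stand-ins:

```python
def warm_up_kv_caches(transformer, kv_caches, condition_latents, prompt_embeds, timestep):
    # One forward pass per condition image; every attention layer stores the
    # K/V it computes rather than contributing a useful prediction.
    kv_caches.set_mode("write")
    for condition_latent in condition_latents:
        transformer(
            hidden_states=condition_latent,
            encoder_hidden_states=prompt_embeds,
            timestep=timestep,
            kv_caches=kv_caches,
            return_dict=False,
        )
```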
@@ -796,7 +799,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                 timestep = t.expand(latents.shape[0]) - 1
 
                 if image is not None:
-                    self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditReadKV)
+                    kv_caches.set_mode("read")
 
                 noise_pred_cond = self.transformer(
                     hidden_states=latent_model_input,
@@ -808,14 +811,13 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                     crop_coords=crops_coords_top_left,
                     attention_kwargs=attention_kwargs,
                     return_dict=False,
+                    kv_caches=kv_caches,
                 )[0].float()
 
                 # perform guidance
                 if self.do_classifier_free_guidance:
                     if image is not None:
-                        self.transformer.set_attention_processors_state(
-                            GlmImageAttenProcessorState.ImageEditDontReadKV
-                        )
+                        kv_caches.set_mode("skip")
                     noise_pred_uncond = self.transformer(
                         hidden_states=latent_model_input,
                         encoder_hidden_states=negative_prompt_embeds,
@@ -826,6 +828,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                         crop_coords=crops_coords_top_left,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
+                        kv_caches=kv_caches,
                     )[0].float()
 
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
@@ -849,8 +852,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                 xm.mark_step()
 
         self._current_timestep = None
-        self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageGen)
-        self.transformer.clear_attention_processors_cache()
+        kv_caches.clear()
 
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype)
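Taken together, the cache lifecycle replaces the old processor-state enum with three explicit modes. A rough sketch of one classifier-free-guidance step plus cleanup, under the same assumptions as above (not the pipeline's literal code; every name is a stand-in):

```python
def cfg_step(transformer, kv_caches, latents, prompt_embeds, negative_prompt_embeds,
             timestep, guidance_scale):
    kv_caches.set_mode("read")   # conditional pass attends to the cached K/V
    noise_pred_cond = transformer(
        hidden_states=latents, encoder_hidden_states=prompt_embeds,
        timestep=timestep, kv_caches=kv_caches, return_dict=False,
    )[0].float()

    kv_caches.set_mode("skip")   # unconditional pass must not see the condition image
    noise_pred_uncond = transformer(
        hidden_states=latents, encoder_hidden_states=negative_prompt_embeds,
        timestep=timestep, kv_caches=kv_caches, return_dict=False,
    )[0].float()

    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

# After the last step, kv_caches.clear() releases the stored tensors, replacing
# the old set_attention_processors_state / clear_attention_processors_cache pair.
```

The "skip" mode exists so the unconditional branch can share the same `kv_caches` object as the conditional branch without reading from it or growing it.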