mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
yiyixuxu
2026-01-08 03:15:17 +01:00
parent 170d0ba160
commit cfe19a31b9
2 changed files with 68 additions and 33 deletions

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
@@ -109,12 +108,49 @@ class GlmImageAdaLayerNormZero(nn.Module):
)
-class GlmImageAttenProcessorState(Enum):
-    ImageGen = "ImageGen"
-    ImageEditWriteKV = "ImageEditWriteKV"
-    ImageEditReadKV = "ImageEditReadKV"
-    ImageEditDontReadKV = "ImageEditNoReadKV"
+class GlmImageLayerKVCache:
+    """KV cache for GlmImage model."""
+
+    def __init__(self):
+        self.k_cache = None
+        self.v_cache = None
+        self.mode: Optional[str] = None  # "write", "read", "skip"
+
+    def store(self, k: torch.Tensor, v: torch.Tensor):
+        if self.k_cache is None:
+            self.k_cache = k
+            self.v_cache = v
+        else:
+            self.k_cache = torch.cat([self.k_cache, k], dim=2)
+            self.v_cache = torch.cat([self.v_cache, v], dim=2)
+
+    def get(self):
+        return self.k_cache, self.v_cache
+
+    def clear(self):
+        self.k_cache = None
+        self.v_cache = None
+        self.mode = None
+
+
+class GlmImageKVCache:
+    """Container for all layers' KV caches."""
+
+    def __init__(self, num_layers: int):
+        self.num_layers = num_layers
+        self.caches = [GlmImageLayerKVCache() for _ in range(num_layers)]
+
+    def __getitem__(self, layer_idx: int) -> GlmImageLayerKVCache:
+        return self.caches[layer_idx]
+
+    def set_mode(self, mode: Optional[str]):
+        if mode is not None and mode not in ["write", "read", "skip"]:
+            raise ValueError(f"Invalid mode: {mode}, must be one of 'write', 'read', 'skip'")
+        for cache in self.caches:
+            cache.mode = mode
+
+    def clear(self):
+        for cache in self.caches:
+            cache.clear()
class GlmImageAttnProcessor:
"""
@@ -128,9 +164,6 @@ class GlmImageAttnProcessor:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("GlmImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
-        self.processor_state = GlmImageAttenProcessorState.ImageGen
-        self.k_cache = None
-        self.v_cache = None
def __call__(
self,
@@ -139,6 +172,7 @@ class GlmImageAttnProcessor:
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        kv_cache: Optional[GlmImageLayerKVCache] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
dtype = encoder_hidden_states.dtype
@@ -172,12 +206,15 @@ class GlmImageAttnProcessor:
key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
-        if self.processor_state == GlmImageAttenProcessorState.ImageEditWriteKV:
-            self.k_cache = key if self.k_cache is None else torch.cat([self.k_cache, key], dim=2)
-            self.v_cache = value if self.v_cache is None else torch.cat([self.v_cache, value], dim=2)
-        elif self.processor_state == GlmImageAttenProcessorState.ImageEditReadKV:
-            key = torch.cat([self.k_cache, key], dim=2) if self.k_cache is not None else key
-            value = torch.cat([self.v_cache, value], dim=2) if self.v_cache is not None else value
+        if kv_cache is not None:
+            if kv_cache.mode == "write":
+                kv_cache.store(key, value)
+            elif kv_cache.mode == "read":
+                k_cache, v_cache = kv_cache.get()
+                key = torch.cat([k_cache, key], dim=2) if k_cache is not None else key
+                value = torch.cat([v_cache, value], dim=2) if v_cache is not None else value
+            elif kv_cache.mode == "skip":
+                pass
# 4. Attention
if attention_mask is not None:
@@ -246,6 +283,7 @@ class GlmImageTransformerBlock(nn.Module):
] = None,
attention_mask: Optional[Dict[str, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
+        kv_cache: Optional[GlmImageLayerKVCache] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Timestep conditioning
(
@@ -270,6 +308,7 @@ class GlmImageTransformerBlock(nn.Module):
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
+            kv_cache=kv_cache,
**attention_kwargs,
)
hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
@@ -464,6 +503,7 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
attention_mask: Optional[torch.Tensor] = None,
+        kv_caches: Optional[GlmImageKVCache] = None,
image_rotary_emb: Optional[
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
] = None,
@@ -491,7 +531,7 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
temb = F.silu(temb)
# 3. Transformer blocks
-        for block in self.transformer_blocks:
+        for idx, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block,
@@ -501,6 +541,7 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
image_rotary_emb,
attention_mask,
attention_kwargs,
+                    kv_caches[idx] if kv_caches is not None else None,
)
else:
hidden_states, encoder_hidden_states = block(
@@ -510,6 +551,7 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
image_rotary_emb,
attention_mask,
attention_kwargs,
+                    kv_cache=kv_caches[idx] if kv_caches is not None else None,
)
# 4. Output norm & projection
@@ -523,12 +565,3 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
-    def set_attention_processors_state(self, state: GlmImageAttenProcessorState):
-        for block in self.transformer_blocks:
-            block.attn1.processor.processor_state = state
-
-    def clear_attention_processors_cache(self):
-        for block in self.transformer_blocks:
-            block.attn1.processor.k_cache = None
-            block.attn1.processor.v_cache = None
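
With set_attention_processors_state and clear_attention_processors_cache removed, callers no longer mutate state on the attention processors; the cache travels through the forward call instead. A hedged sketch of the calling pattern this diff implies (condition_inputs, denoise_inputs and uncond_inputs are placeholders for the usual hidden_states / encoder_hidden_states / timestep / rotary-embedding arguments):

# transformer is a GlmImageTransformer2DModel
kv_caches = GlmImageKVCache(num_layers=transformer.config.num_layers)

kv_caches.set_mode("write")    # conditioning pass: each layer appends its K/V to the cache
transformer(**condition_inputs, kv_caches=kv_caches)

kv_caches.set_mode("read")     # conditional denoising: cached K/V are concatenated before the new ones
noise_pred_cond = transformer(**denoise_inputs, kv_caches=kv_caches, return_dict=False)[0]

kv_caches.set_mode("skip")     # unconditional branch: the cache is passed but ignored
noise_pred_uncond = transformer(**uncond_inputs, kv_caches=kv_caches, return_dict=False)[0]

kv_caches.clear()              # after sampling, drop the cached tensors
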

View File

@@ -27,7 +27,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...loaders import CogView4LoraLoaderMixin
from ...models import AutoencoderKL, GlmImageTransformer2DModel
-from ...models.transformers.transformer_glm_image import GlmImageAttenProcessorState
+from ...models.transformers.transformer_glm_image import GlmImageKVCache
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
@@ -719,8 +719,10 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
latents=latents,
)
+        kv_caches = GlmImageKVCache(num_layers=self.transformer.config.num_layers)
if image is not None and condition_images_prior_token_id is not None:
-            self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditWriteKV)
+            kv_caches.set_mode("write")
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.latent_channels, 1, 1)
@@ -747,6 +749,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
target_size=torch.tensor([condition_image.shape[-2:]], device=device),
crop_coords=torch.zeros((1, 2), device=device),
attention_kwargs=attention_kwargs,
+                kv_caches=kv_caches,
)
# 6. Prepare additional timestep conditions
@@ -796,7 +799,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
timestep = t.expand(latents.shape[0]) - 1
if image is not None:
-                    self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditReadKV)
+                    kv_caches.set_mode("read")
noise_pred_cond = self.transformer(
hidden_states=latent_model_input,
@@ -808,14 +811,13 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,
return_dict=False,
+                    kv_caches=kv_caches,
)[0].float()
# perform guidance
if self.do_classifier_free_guidance:
if image is not None:
-                        self.transformer.set_attention_processors_state(
-                            GlmImageAttenProcessorState.ImageEditDontReadKV
-                        )
+                        kv_caches.set_mode("skip")
noise_pred_uncond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=negative_prompt_embeds,
@@ -826,6 +828,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,
return_dict=False,
+                        kv_caches=kv_caches,
)[0].float()
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
@@ -849,8 +852,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
xm.mark_step()
self._current_timestep = None
-        self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageGen)
-        self.transformer.clear_attention_processors_cache()
+        kv_caches.clear()
if not output_type == "latent":
latents = latents.to(self.vae.dtype)