mirror of https://github.com/huggingface/diffusers.git
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
@@ -109,12 +108,49 @@ class GlmImageAdaLayerNormZero(nn.Module):
         )


-class GlmImageAttenProcessorState(Enum):
-    ImageGen = "ImageGen"
-    ImageEditWriteKV = "ImageEditWriteKV"
-    ImageEditReadKV = "ImageEditReadKV"
-    ImageEditDontReadKV = "ImageEditNoReadKV"
+class GlmImageLayerKVCache:
+    """KV cache for GlmImage model."""
+
+    def __init__(self):
+        self.k_cache = None
+        self.v_cache = None
+        self.mode: Optional[str] = None  # "write", "read", "skip"
+
+    def store(self, k: torch.Tensor, v: torch.Tensor):
+        if self.k_cache is None:
+            self.k_cache = k
+            self.v_cache = v
+        else:
+            self.k_cache = torch.cat([self.k_cache, k], dim=2)
+            self.v_cache = torch.cat([self.v_cache, v], dim=2)
+
+    def get(self):
+        return self.k_cache, self.v_cache
+
+    def clear(self):
+        self.k_cache = None
+        self.v_cache = None
+        self.mode = None
+
+
+class GlmImageKVCache:
+    """Container for all layers' KV caches."""
+
+    def __init__(self, num_layers: int):
+        self.num_layers = num_layers
+        self.caches = [GlmImageLayerKVCache() for _ in range(num_layers)]
+
+    def __getitem__(self, layer_idx: int) -> GlmImageLayerKVCache:
+        return self.caches[layer_idx]
+
+    def set_mode(self, mode: Optional[str]):
+        if mode is not None and mode not in ["write", "read", "skip"]:
+            raise ValueError(f"Invalid mode: {mode}, must be one of 'write', 'read', 'skip'")
+        for cache in self.caches:
+            cache.mode = mode
+
+    def clear(self):
+        for cache in self.caches:
+            cache.clear()


 class GlmImageAttnProcessor:
     """
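For orientation, here is a minimal sketch (not part of the commit) of how the new container is driven, assuming GlmImageLayerKVCache and GlmImageKVCache from above are in scope; the number of layers is an illustrative value:

kv_caches = GlmImageKVCache(num_layers=2)

kv_caches.set_mode("write")          # propagates the mode to every layer cache
print(kv_caches[0].mode)             # "write" -- per-layer access via __getitem__
print(kv_caches[1].k_cache is None)  # True -- nothing stored yet

# kv_caches.set_mode("bogus")        # would raise ValueError: only "write", "read", "skip" (or None) are accepted
kv_caches.clear()                    # resets tensors and mode on every layer
print(kv_caches[0].mode)             # None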
@@ -128,9 +164,6 @@ class GlmImageAttnProcessor:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("GlmImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
-        self.processor_state = GlmImageAttenProcessorState.ImageGen
-        self.k_cache = None
-        self.v_cache = None

     def __call__(
         self,
@@ -139,6 +172,7 @@ class GlmImageAttnProcessor:
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        kv_cache: Optional[GlmImageLayerKVCache] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         dtype = encoder_hidden_states.dtype

@@ -172,12 +206,15 @@ class GlmImageAttnProcessor:
             key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
         )

-        if self.processor_state == GlmImageAttenProcessorState.ImageEditWriteKV:
-            self.k_cache = key if self.k_cache is None else torch.cat([self.k_cache, key], dim=2)
-            self.v_cache = value if self.v_cache is None else torch.cat([self.v_cache, value], dim=2)
-        elif self.processor_state == GlmImageAttenProcessorState.ImageEditReadKV:
-            key = torch.cat([self.k_cache, key], dim=2) if self.k_cache is not None else key
-            value = torch.cat([self.v_cache, value], dim=2) if self.v_cache is not None else value
+        if kv_cache is not None:
+            if kv_cache.mode == "write":
+                kv_cache.store(key, value)
+            elif kv_cache.mode == "read":
+                k_cache, v_cache = kv_cache.get()
+                key = torch.cat([k_cache, key], dim=2) if k_cache is not None else key
+                value = torch.cat([v_cache, value], dim=2) if v_cache is not None else value
+            elif kv_cache.mode == "skip":
+                pass

         # 4. Attention
         if attention_mask is not None:
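A side note on the "read" branch above (an illustration, not code from the diff): concatenating cached keys/values along dim=2 lengthens only the key/value sequence seen by scaled_dot_product_attention, so generated tokens can attend to the cached reference-image tokens while the output keeps the query length. A self-contained shape check with made-up dimensions:

import torch
import torch.nn.functional as F

batch, heads, head_dim = 1, 8, 64
q_len, cur_len, cached_len = 32, 32, 16

query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, cur_len, head_dim)
value = torch.randn(batch, heads, cur_len, head_dim)
k_cache = torch.randn(batch, heads, cached_len, head_dim)
v_cache = torch.randn(batch, heads, cached_len, head_dim)

# "read" mode: prepend cached reference-image tokens to the current K/V.
key = torch.cat([k_cache, key], dim=2)      # (1, 8, 48, 64)
value = torch.cat([v_cache, value], dim=2)  # (1, 8, 48, 64)

out = F.scaled_dot_product_attention(query, key, value)
print(out.shape)  # torch.Size([1, 8, 32, 64]) -- same length as the query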
@@ -246,6 +283,7 @@ class GlmImageTransformerBlock(nn.Module):
         ] = None,
         attention_mask: Optional[Dict[str, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
+        kv_cache: Optional[GlmImageLayerKVCache] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # 1. Timestep conditioning
         (
@@ -270,6 +308,7 @@ class GlmImageTransformerBlock(nn.Module):
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
             attention_mask=attention_mask,
+            kv_cache=kv_cache,
             **attention_kwargs,
         )
         hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
@@ -464,6 +503,7 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
         attention_mask: Optional[torch.Tensor] = None,
+        kv_caches: Optional[GlmImageKVCache] = None,
         image_rotary_emb: Optional[
             Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
         ] = None,
@@ -491,7 +531,7 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
         temb = F.silu(temb)

         # 3. Transformer blocks
-        for block in self.transformer_blocks:
+        for idx, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
                     block,
@@ -501,6 +541,7 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
                     image_rotary_emb,
                     attention_mask,
                     attention_kwargs,
+                    kv_caches[idx] if kv_caches is not None else None,
                 )
             else:
                 hidden_states, encoder_hidden_states = block(
@@ -510,6 +551,7 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
                     image_rotary_emb,
                     attention_mask,
                     attention_kwargs,
+                    kv_cache=kv_caches[idx] if kv_caches is not None else None,
                 )

         # 4. Output norm & projection
@@ -523,12 +565,3 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
-
-    def set_attention_processors_state(self, state: GlmImageAttenProcessorState):
-        for block in self.transformer_blocks:
-            block.attn1.processor.processor_state = state
-
-    def clear_attention_processors_cache(self):
-        for block in self.transformer_blocks:
-            block.attn1.processor.k_cache = None
-            block.attn1.processor.v_cache = None
@@ -27,7 +27,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import VaeImageProcessor
 from ...loaders import CogView4LoraLoaderMixin
 from ...models import AutoencoderKL, GlmImageTransformer2DModel
-from ...models.transformers.transformer_glm_image import GlmImageAttenProcessorState
+from ...models.transformers.transformer_glm_image import GlmImageKVCache
 from ...pipelines.pipeline_utils import DiffusionPipeline
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import is_torch_xla_available, logging, replace_example_docstring
@@ -719,8 +719,10 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
             latents=latents,
         )

+        kv_caches = GlmImageKVCache(num_layers=self.transformer.config.num_layers)
+
         if image is not None and condition_images_prior_token_id is not None:
-            self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditWriteKV)
+            kv_caches.set_mode("write")
             latents_mean = (
                 torch.tensor(self.vae.config.latents_mean)
                 .view(1, self.vae.config.latent_channels, 1, 1)
@@ -747,6 +749,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                 target_size=torch.tensor([condition_image.shape[-2:]], device=device),
                 crop_coords=torch.zeros((1, 2), device=device),
                 attention_kwargs=attention_kwargs,
+                kv_caches=kv_caches,
             )

         # 6. Prepare additional timestep conditions
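During this condition-image forward pass each attention layer calls store() on its own GlmImageLayerKVCache. As a small, self-contained illustration (assuming the cache class from the transformer module is in scope; shapes are made up), repeated store() calls accumulate along the sequence dimension, which is how keys/values would pile up if more than one condition forward wrote into the same cache:

import torch

cache = GlmImageLayerKVCache()
cache.mode = "write"

k1 = torch.randn(1, 8, 16, 64)   # illustrative (batch, heads, cond_seq, head_dim)
v1 = torch.randn(1, 8, 16, 64)
cache.store(k1, v1)              # first write: tensors are kept as-is

k2 = torch.randn(1, 8, 16, 64)
v2 = torch.randn(1, 8, 16, 64)
cache.store(k2, v2)              # second write: appended along dim=2

k, v = cache.get()
print(k.shape, v.shape)          # torch.Size([1, 8, 32, 64]) for both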
@@ -796,7 +799,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                 timestep = t.expand(latents.shape[0]) - 1

                 if image is not None:
-                    self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditReadKV)
+                    kv_caches.set_mode("read")

                 noise_pred_cond = self.transformer(
                     hidden_states=latent_model_input,
@@ -808,14 +811,13 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                     crop_coords=crops_coords_top_left,
                     attention_kwargs=attention_kwargs,
                     return_dict=False,
+                    kv_caches=kv_caches,
                 )[0].float()

                 # perform guidance
                 if self.do_classifier_free_guidance:
                     if image is not None:
-                        self.transformer.set_attention_processors_state(
-                            GlmImageAttenProcessorState.ImageEditDontReadKV
-                        )
+                        kv_caches.set_mode("skip")
                     noise_pred_uncond = self.transformer(
                         hidden_states=latent_model_input,
                         encoder_hidden_states=negative_prompt_embeds,
@@ -826,6 +828,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                         crop_coords=crops_coords_top_left,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
+                        kv_caches=kv_caches,
                     )[0].float()

                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
@@ -849,8 +852,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                 xm.mark_step()

         self._current_timestep = None
-        self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageGen)
-        self.transformer.clear_attention_processors_cache()
+        kv_caches.clear()

         if not output_type == "latent":
            latents = latents.to(self.vae.dtype)
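Pulling the pieces together, a compact, self-contained sketch of the cache lifecycle the pipeline changes above implement: write once for the reference image, read in the conditional branch of each denoising step, skip in the unconditional (classifier-free guidance) branch, and clear when sampling ends. The helper below is a standalone copy of the branch added to GlmImageAttnProcessor, written here only for illustration, and assumes GlmImageLayerKVCache is in scope; shapes are made up.

import torch

def apply_kv_cache(key, value, cache):
    # Standalone mirror of the mode dispatch in GlmImageAttnProcessor.__call__.
    if cache is not None:
        if cache.mode == "write":
            cache.store(key, value)
        elif cache.mode == "read":
            k_cache, v_cache = cache.get()
            key = torch.cat([k_cache, key], dim=2) if k_cache is not None else key
            value = torch.cat([v_cache, value], dim=2) if v_cache is not None else value
        elif cache.mode == "skip":
            pass
    return key, value

cache = GlmImageLayerKVCache()
key = torch.randn(1, 8, 32, 64)
value = torch.randn(1, 8, 32, 64)

cache.mode = "write"
apply_kv_cache(key, value, cache)         # condition pass: reference K/V are stored

cache.mode = "read"
k, v = apply_kv_cache(key, value, cache)  # conditional denoising branch
print(k.shape)                            # torch.Size([1, 8, 64, 64])

cache.mode = "skip"
k, v = apply_kv_cache(key, value, cache)  # unconditional (CFG) branch, cache ignored
print(k.shape)                            # torch.Size([1, 8, 32, 64])

cache.clear()                             # end of sampling, as in the pipeline above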