mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
* init * encode with glm * draft schedule * feat(scheduler): Add CogView scheduler implementation * feat(embeddings): add CogView 2D rotary positional embedding * 1 * Update pipeline_cogview4.py * fix the timestep init and sigma * update latent * draft patch(not work) * fix * [WIP][cogview4]: implement initial CogView4 pipeline Implement the basic CogView4 pipeline structure with the following changes: - Add CogView4 pipeline implementation - Implement DDIM scheduler for CogView4 - Add CogView3Plus transformer architecture - Update embedding models Current limitations: - CFG implementation uses padding for sequence length alignment - Need to verify transformer inference alignment with Megatron TODO: - Consider separate forward passes for condition/uncondition instead of padding approach * [WIP][cogview4][refactor]: Split condition/uncondition forward pass in CogView4 pipeline Split the forward pass for conditional and unconditional predictions in the CogView4 pipeline to match the original implementation. The noise prediction is now done separately for each case before combining them for guidance. However, the results still need improvement. This is a work in progress as the generated images are not yet matching expected quality. * use with -2 hidden state * remove text_projector * 1 * [WIP] Add tensor-reload to align input from transformer block * [WIP] for older glm * use with cogview4 transformers forward twice of u and uc * Update convert_cogview4_to_diffusers.py * remove this * use main example * change back * reset * setback * back * back 4 * Fix qkv conversion logic for CogView4 to Diffusers format * back5 * revert to sat to cogview4 version * update a new convert from megatron * [WIP][cogview4]: implement CogView4 attention processor Add CogView4AttnProcessor class for implementing scaled dot-product attention with rotary embeddings for the CogVideoX model. 
This processor concatenates encoder and hidden states, applies QKV projections and RoPE, but does not include spatial normalization. TODO: - Fix incorrect QKV projection weights - Resolve ~25% error in RoPE implementation compared to Megatron * [cogview4] implement CogView4 transformer block Implement CogView4 transformer block following the Megatron architecture: - Add multi-modulate and multi-gate mechanisms for adaptive layer normalization - Implement dual-stream attention with encoder-decoder structure - Add feed-forward network with GELU activation - Support rotary position embeddings for image tokens The implementation follows the original CogView4 architecture while adapting it to work within the diffusers framework. * with new attn * [bugfix] fix dimension mismatch in CogView4 attention * [cogview4][WIP]: update final normalization in CogView4 transformer Refactored the final normalization layer in CogView4 transformer to use separate layernorm and AdaLN operations instead of combined AdaLayerNormContinuous. This matches the original implementation but needs validation. Needs verification against reference implementation. * 1 * put back * Update transformer_cogview4.py * change time_shift * Update pipeline_cogview4.py * change timesteps * fix * change text_encoder_id * [cogview4][rope] align RoPE implementation with Megatron - Implement apply_rope method in attention processor to match Megatron's implementation - Update position embeddings to ensure compatibility with Megatron-style rotary embeddings - Ensure consistent rotary position encoding across attention layers This change improves compatibility with Megatron-based models and provides better alignment with the original implementation's positional encoding approach. * [cogview4][bugfix] apply silu activation to time embeddings in CogView4 Applied silu activation to time embeddings before splitting into conditional and unconditional parts in CogView4Transformer2DModel. 
This matches the original implementation and helps ensure correct time conditioning behavior. * [cogview4][chore] clean up pipeline code - Remove commented out code and debug statements - Remove unused retrieve_timesteps function - Clean up code formatting and documentation This commit focuses on code cleanup in the CogView4 pipeline implementation, removing unnecessary commented code and improving readability without changing functionality. * [cogview4][scheduler] Implement CogView4 scheduler and pipeline * now It work * add timestep * batch * change convert scipt * refactor pt. 1; make style * refactor pt. 2 * refactor pt. 3 * add tests * make fix-copies * update toctree.yml * use flow match scheduler instead of custom * remove scheduling_cogview.py * add tiktoken to test dependencies * Update src/diffusers/models/embeddings.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * apply suggestions from review * use diffusers apply_rotary_emb * update flow match scheduler to accept timesteps * fix comment * apply review sugestions * Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py Co-authored-by: YiYi Xu <yixu310@gmail.com> --------- Co-authored-by: 三洋三洋 <1258009915@qq.com> Co-authored-by: OleehyO <leehy0357@gmail.com> Co-authored-by: Aryan <aryan@huggingface.co> Co-authored-by: YiYi Xu <yixu310@gmail.com>
235 lines
8.6 KiB
Python
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, GlmConfig, GlmForCausalLM

from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import enable_full_determinism, torch_device

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np


enable_full_determinism()


class CogView4PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = CogView4Pipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )

    supports_dduf = False
    test_xformers_attention = False
    test_layerwise_casting = True

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = CogView4Transformer2DModel(
            patch_size=2,
            in_channels=4,
            num_layers=2,
            attention_head_dim=4,
            num_attention_heads=4,
            out_channels=4,
            text_embed_dim=32,
            time_embed_dim=8,
            condition_dim=4,
        )

        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            sample_size=128,
        )

        torch.manual_seed(0)
        scheduler = FlowMatchEulerDiscreteScheduler(
            base_shift=0.25,
            max_shift=0.75,
            base_image_seq_len=256,
            use_dynamic_shifting=True,
            time_shift_type="linear",
        )

        torch.manual_seed(0)
        text_encoder_config = GlmConfig(
            hidden_size=32, intermediate_size=8, num_hidden_layers=2, num_attention_heads=4, head_dim=8
        )
        text_encoder = GlmForCausalLM(text_encoder_config)
        # TODO(aryan): change this to THUDM/CogView4 once released
        tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "dance monkey",
            "negative_prompt": "",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "height": 16,
            "width": 16,
            "max_sequence_length": 16,
            "output_type": "pt",
        }
        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs)[0]
        generated_image = image[0]

        self.assertEqual(generated_image.shape, (3, 16, 16))
        # Placeholder check: comparing against random values with a 1e10 tolerance
        # only guards against NaN/inf or exploding outputs, not against a reference image.
        expected_image = torch.randn(3, 16, 16)
        max_diff = np.abs(generated_image - expected_image).max()
        self.assertLessEqual(max_diff, 1e10)

    def test_callback_inputs(self):
        sig = inspect.signature(self.pipeline_class.__call__)
        has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
        has_callback_step_end = "callback_on_step_end" in sig.parameters

        if not (has_callback_tensor_inputs and has_callback_step_end):
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        self.assertTrue(
            hasattr(pipe, "_callback_tensor_inputs"),
            f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
        )

        def callback_inputs_subset(pipe, i, t, callback_kwargs):
            # iterate over callback args
            for tensor_name, tensor_value in callback_kwargs.items():
                # check that we're only passing in allowed tensor inputs
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        def callback_inputs_all(pipe, i, t, callback_kwargs):
            for tensor_name in pipe._callback_tensor_inputs:
                assert tensor_name in callback_kwargs

            # iterate over callback args
            for tensor_name, tensor_value in callback_kwargs.items():
                # check that we're only passing in allowed tensor inputs
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        inputs = self.get_dummy_inputs(torch_device)

        # Test passing in a subset
        inputs["callback_on_step_end"] = callback_inputs_subset
        inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
        output = pipe(**inputs)[0]

        # Test passing in everything
        inputs["callback_on_step_end"] = callback_inputs_all
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]

        def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
            is_last = i == (pipe.num_timesteps - 1)
            if is_last:
                callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
            return callback_kwargs

        inputs["callback_on_step_end"] = callback_inputs_change_tensor
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]
        assert output.abs().sum() < 1e10

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)

    def test_attention_slicing_forward_pass(
        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
    ):
        if not self.test_attention_slicing:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        output_without_slicing = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=1)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing1 = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=2)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing2 = pipe(**inputs)[0]

        if test_max_difference:
            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
            self.assertLess(
                max(max_diff1, max_diff2),
                expected_max_diff,
                "Attention slicing should not affect the inference results",
            )