1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Hunyuan] feat: support chunked ff. (#8397)

feat: support chunked ff.
This commit is contained in:
Sayak Paul
2024-06-05 08:12:18 +04:00
committed by GitHub
parent 14f7b545bd
commit a8ad6664c2
2 changed files with 63 additions and 0 deletions

View File

@@ -166,6 +166,7 @@ class HunyuanDiTBlock(nn.Module):
self._chunk_size = None
self._chunk_dim = 0
# Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
@@ -529,3 +530,45 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Sets the attention processor to use [feed forward
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
Parameters:
chunk_size (`int`, *optional*):
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
over each tensor of dim=`dim`.
dim (`int`, *optional*, defaults to `0`):
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
or dim=1 (sequence length).
"""
if dim not in [0, 1]:
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
# By default chunk size is 1
chunk_size = chunk_size or 1
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
def disable_forward_chunking(self):
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
for module in self.children():
fn_recursive_feed_forward(module, None, 0)

View File

@@ -228,6 +228,26 @@ class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, 1e-4)
def test_feed_forward_chunking(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_no_chunking = image[0, -3:, -3:, -1]
pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_chunking = image[0, -3:, -3:, -1]
max_diff = np.abs(to_np(image_slice_no_chunking) - to_np(image_slice_chunking)).max()
self.assertLess(max_diff, 1e-4)
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()