mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Improve memory text to video (#3930)
* Improve memory text to video * Apply suggestions from code review * add test * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * finish test setup --------- Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
b8a5dda56e
commit
332d2bbea3
@@ -119,6 +119,15 @@ class BasicTransformerBlock(nn.Module):
|
||||
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
||||
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
|
||||
|
||||
# let chunk size default to None
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = 0
|
||||
|
||||
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
|
||||
# Sets chunk feed-forward
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_dim = dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
@@ -141,6 +150,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
@@ -171,7 +181,20 @@ class BasicTransformerBlock(nn.Module):
|
||||
if self.use_ada_layer_norm_zero:
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
||||
raise ValueError(
|
||||
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
||||
)
|
||||
|
||||
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
||||
ff_output = torch.cat(
|
||||
[self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
|
||||
dim=self._chunk_dim,
|
||||
)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
if self.use_ada_layer_norm_zero:
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
@@ -389,6 +389,46 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def enable_forward_chunking(self, chunk_size=None, dim=0):
|
||||
"""
|
||||
Sets the attention processor to use [feed forward
|
||||
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
|
||||
|
||||
Parameters:
|
||||
chunk_size (`int`, *optional*):
|
||||
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
|
||||
over each tensor of dim=`dim`.
|
||||
dim (`int`, *optional*, defaults to `0`):
|
||||
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
|
||||
or dim=1 (sequence length).
|
||||
"""
|
||||
if dim not in [0, 1]:
|
||||
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
|
||||
|
||||
# By default chunk size is 1
|
||||
chunk_size = chunk_size or 1
|
||||
|
||||
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
||||
if hasattr(module, "set_chunk_feed_forward"):
|
||||
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_feed_forward(child, chunk_size, dim)
|
||||
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, chunk_size, dim)
|
||||
|
||||
def disable_forward_chunking(self):
|
||||
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
||||
if hasattr(module, "set_chunk_feed_forward"):
|
||||
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_feed_forward(child, chunk_size, dim)
|
||||
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, None, 0)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
|
||||
@@ -634,6 +634,9 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 6.1 Chunk feed-forward computation to save memory
|
||||
self.unet.enable_forward_chunking(chunk_size=1, dim=1)
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
|
||||
@@ -709,6 +709,9 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 6.1 Chunk feed-forward computation to save memory
|
||||
self.unet.enable_forward_chunking(chunk_size=1, dim=1)
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
|
||||
@@ -399,5 +399,23 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert (sample - on_sample).abs().max() < 1e-4
|
||||
assert (sample - off_sample).abs().max() < 1e-4
|
||||
|
||||
def test_feed_forward_chunking(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict["norm_num_groups"] = 32
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)[0]
|
||||
|
||||
model.enable_forward_chunking()
|
||||
with torch.no_grad():
|
||||
output_2 = model(**inputs_dict)[0]
|
||||
|
||||
self.assertEqual(output.shape, output_2.shape, "Shape doesn't match")
|
||||
assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2
|
||||
|
||||
|
||||
# (todo: sayakpaul) implement SLOW tests.
|
||||
|
||||
Reference in New Issue
Block a user