mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Cosmos (#10660)
* begin transformer conversion * refactor * refactor * refactor * refactor * refactor * refactor * update * add conversion script * add pipeline * make fix-copies * remove einops * update docs * gradient checkpointing * add transformer test * update * debug * remove prints * match sigmas * add vae pt. 1 * finish CV* vae * update * update * update * update * update * update * make fix-copies * update * make fix-copies * fix * update * update * make fix-copies * update * update tests * handle device and dtype for safety checker; required in latest diffusers * remove enable_gqa and use repeat_interleave instead * enforce safety checker; use dummy checker in fast tests * add review suggestion for ONNX export Co-Authored-By: Asfiya Baig <asfiyab@nvidia.com> * fix safety_checker issues when not passed explicitly We could either do what's done in this commit, or update the Cosmos examples to explicitly pass the safety checker * use cosmos guardrail package * auto format docs * update conversion script to support 14B models * update name CosmosPipeline -> CosmosTextToWorldPipeline * update docs * fix docs * fix group offload test failing for vae --------- Co-authored-by: Asfiya Baig <asfiyab@nvidia.com>
This commit is contained in:
@@ -295,6 +295,8 @@
|
||||
title: CogView4Transformer2DModel
|
||||
- local: api/models/consisid_transformer3d
|
||||
title: ConsisIDTransformer3DModel
|
||||
- local: api/models/cosmos_transformer3d
|
||||
title: CosmosTransformer3DModel
|
||||
- local: api/models/dit_transformer2d
|
||||
title: DiTTransformer2DModel
|
||||
- local: api/models/easyanimate_transformer3d
|
||||
@@ -363,6 +365,8 @@
|
||||
title: AutoencoderKLAllegro
|
||||
- local: api/models/autoencoderkl_cogvideox
|
||||
title: AutoencoderKLCogVideoX
|
||||
- local: api/models/autoencoderkl_cosmos
|
||||
title: AutoencoderKLCosmos
|
||||
- local: api/models/autoencoder_kl_hunyuan_video
|
||||
title: AutoencoderKLHunyuanVideo
|
||||
- local: api/models/autoencoderkl_ltx_video
|
||||
@@ -433,6 +437,8 @@
|
||||
title: ControlNet-XS with Stable Diffusion XL
|
||||
- local: api/pipelines/controlnet_union
|
||||
title: ControlNetUnion
|
||||
- local: api/pipelines/cosmos
|
||||
title: Cosmos
|
||||
- local: api/pipelines/dance_diffusion
|
||||
title: Dance Diffusion
|
||||
- local: api/pipelines/ddim
|
||||
|
||||
40
docs/source/en/api/models/autoencoderkl_cosmos.md
Normal file
40
docs/source/en/api/models/autoencoderkl_cosmos.md
Normal file
@@ -0,0 +1,40 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLCosmos
|
||||
|
||||
[Cosmos Tokenizers](https://github.com/NVIDIA/Cosmos-Tokenizer).
|
||||
|
||||
Supported models:
|
||||
- [nvidia/Cosmos-1.0-Tokenizer-CV8x8x8](https://huggingface.co/nvidia/Cosmos-1.0-Tokenizer-CV8x8x8)
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKLCosmos
|
||||
|
||||
vae = AutoencoderKLCosmos.from_pretrained("nvidia/Cosmos-1.0-Tokenizer-CV8x8x8", subfolder="vae")
|
||||
```
|
||||
|
||||
## AutoencoderKLCosmos
|
||||
|
||||
[[autodoc]] AutoencoderKLCosmos
|
||||
- decode
|
||||
- encode
|
||||
- all
|
||||
|
||||
## AutoencoderKLOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
30
docs/source/en/api/models/cosmos_transformer3d.md
Normal file
30
docs/source/en/api/models/cosmos_transformer3d.md
Normal file
@@ -0,0 +1,30 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# CosmosTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D video-like data was introduced in [Cosmos World Foundation Model Platform for Physical AI](https://huggingface.co/papers/2501.03575) by NVIDIA.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import CosmosTransformer3DModel
|
||||
|
||||
transformer = CosmosTransformer3DModel.from_pretrained("nvidia/Cosmos-1.0-Diffusion-7B-Text2World", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## CosmosTransformer3DModel
|
||||
|
||||
[[autodoc]] CosmosTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
41
docs/source/en/api/pipelines/cosmos.md
Normal file
41
docs/source/en/api/pipelines/cosmos.md
Normal file
@@ -0,0 +1,41 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# Cosmos
|
||||
|
||||
[Cosmos World Foundation Model Platform for Physical AI](https://huggingface.co/papers/2501.03575) by NVIDIA.
|
||||
|
||||
*Physical AI needs to be trained digitally first. It needs a digital twin of itself, the policy model, and a digital twin of the world, the world model. In this paper, we present the Cosmos World Foundation Model Platform to help developers build customized world models for their Physical AI setups. We position a world foundation model as a general-purpose world model that can be fine-tuned into customized world models for downstream applications. Our platform covers a video curation pipeline, pre-trained world foundation models, examples of post-training of pre-trained world foundation models, and video tokenizers. To help Physical AI builders solve the most critical problems of our society, we make our platform open-source and our models open-weight with permissive licenses available via https://github.com/NVIDIA/Cosmos.*
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## CosmosTextToWorldPipeline
|
||||
|
||||
[[autodoc]] CosmosTextToWorldPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## CosmosVideoToWorldPipeline
|
||||
|
||||
[[autodoc]] CosmosVideoToWorldPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## CosmosPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
|
||||
352
scripts/convert_cosmos_to_diffusers.py
Normal file
352
scripts/convert_cosmos_to_diffusers.py
Normal file
@@ -0,0 +1,352 @@
|
||||
import argparse
|
||||
import pathlib
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from diffusers import AutoencoderKLCosmos, CosmosTextToWorldPipeline, CosmosTransformer3DModel, EDMEulerScheduler
|
||||
|
||||
|
||||
def remove_keys_(key: str, state_dict: Dict[str, Any]):
|
||||
state_dict.pop(key)
|
||||
|
||||
|
||||
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
|
||||
block_index = int(key.split(".")[1].removeprefix("block"))
|
||||
new_key = key
|
||||
|
||||
old_prefix = f"blocks.block{block_index}"
|
||||
new_prefix = f"transformer_blocks.{block_index}"
|
||||
new_key = new_prefix + new_key.removeprefix(old_prefix)
|
||||
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"t_embedder.1": "time_embed.t_embedder",
|
||||
"affline_norm": "time_embed.norm",
|
||||
".blocks.0.block.attn": ".attn1",
|
||||
".blocks.1.block.attn": ".attn2",
|
||||
".blocks.2.block": ".ff",
|
||||
".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
|
||||
".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
|
||||
".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
|
||||
".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
|
||||
".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
|
||||
".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
|
||||
"to_q.0": "to_q",
|
||||
"to_q.1": "norm_q",
|
||||
"to_k.0": "to_k",
|
||||
"to_k.1": "norm_k",
|
||||
"to_v.0": "to_v",
|
||||
"layer1": "net.0.proj",
|
||||
"layer2": "net.2",
|
||||
"proj.1": "proj",
|
||||
"x_embedder": "patch_embed",
|
||||
"extra_pos_embedder": "learnable_pos_embed",
|
||||
"final_layer.adaLN_modulation.1": "norm_out.linear_1",
|
||||
"final_layer.adaLN_modulation.2": "norm_out.linear_2",
|
||||
"final_layer.linear": "proj_out",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"blocks.block": rename_transformer_blocks_,
|
||||
"logvar.0.freqs": remove_keys_,
|
||||
"logvar.0.phases": remove_keys_,
|
||||
"logvar.1.weight": remove_keys_,
|
||||
"pos_embedder.seq": remove_keys_,
|
||||
}
|
||||
|
||||
TRANSFORMER_CONFIGS = {
|
||||
"Cosmos-1.0-Diffusion-7B-Text2World": {
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 28,
|
||||
"mlp_ratio": 4.0,
|
||||
"text_embed_dim": 1024,
|
||||
"adaln_lora_dim": 256,
|
||||
"max_size": (128, 240, 240),
|
||||
"patch_size": (1, 2, 2),
|
||||
"rope_scale": (2.0, 1.0, 1.0),
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": "learnable",
|
||||
},
|
||||
"Cosmos-1.0-Diffusion-7B-Video2World": {
|
||||
"in_channels": 16 + 1,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 28,
|
||||
"mlp_ratio": 4.0,
|
||||
"text_embed_dim": 1024,
|
||||
"adaln_lora_dim": 256,
|
||||
"max_size": (128, 240, 240),
|
||||
"patch_size": (1, 2, 2),
|
||||
"rope_scale": (2.0, 1.0, 1.0),
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": "learnable",
|
||||
},
|
||||
"Cosmos-1.0-Diffusion-14B-Text2World": {
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 40,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 36,
|
||||
"mlp_ratio": 4.0,
|
||||
"text_embed_dim": 1024,
|
||||
"adaln_lora_dim": 256,
|
||||
"max_size": (128, 240, 240),
|
||||
"patch_size": (1, 2, 2),
|
||||
"rope_scale": (2.0, 2.0, 2.0),
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": "learnable",
|
||||
},
|
||||
"Cosmos-1.0-Diffusion-14B-Video2World": {
|
||||
"in_channels": 16 + 1,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 40,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 36,
|
||||
"mlp_ratio": 4.0,
|
||||
"text_embed_dim": 1024,
|
||||
"adaln_lora_dim": 256,
|
||||
"max_size": (128, 240, 240),
|
||||
"patch_size": (1, 2, 2),
|
||||
"rope_scale": (2.0, 2.0, 2.0),
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": "learnable",
|
||||
},
|
||||
}
|
||||
|
||||
VAE_KEYS_RENAME_DICT = {
|
||||
"down.0": "down_blocks.0",
|
||||
"down.1": "down_blocks.1",
|
||||
"down.2": "down_blocks.2",
|
||||
"up.0": "up_blocks.2",
|
||||
"up.1": "up_blocks.1",
|
||||
"up.2": "up_blocks.0",
|
||||
".block.": ".resnets.",
|
||||
"downsample": "downsamplers.0",
|
||||
"upsample": "upsamplers.0",
|
||||
"mid.block_1": "mid_block.resnets.0",
|
||||
"mid.attn_1.0": "mid_block.attentions.0",
|
||||
"mid.attn_1.1": "mid_block.temp_attentions.0",
|
||||
"mid.block_2": "mid_block.resnets.1",
|
||||
".q.conv3d": ".to_q",
|
||||
".k.conv3d": ".to_k",
|
||||
".v.conv3d": ".to_v",
|
||||
".proj_out.conv3d": ".to_out.0",
|
||||
".0.conv3d": ".conv_s",
|
||||
".1.conv3d": ".conv_t",
|
||||
"conv1.conv3d": "conv1",
|
||||
"conv2.conv3d": "conv2",
|
||||
"conv3.conv3d": "conv3",
|
||||
"nin_shortcut.conv3d": "conv_shortcut",
|
||||
"quant_conv.conv3d": "quant_conv",
|
||||
"post_quant_conv.conv3d": "post_quant_conv",
|
||||
}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {
|
||||
"wavelets": remove_keys_,
|
||||
"_arange": remove_keys_,
|
||||
"patch_size_buffer": remove_keys_,
|
||||
}
|
||||
|
||||
VAE_CONFIGS = {
|
||||
"CV8x8x8-0.1": {
|
||||
"name": "nvidia/Cosmos-0.1-Tokenizer-CV8x8x8",
|
||||
"diffusers_config": {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 16,
|
||||
"encoder_block_out_channels": (128, 256, 512, 512),
|
||||
"decode_block_out_channels": (256, 512, 512, 512),
|
||||
"attention_resolutions": (32,),
|
||||
"resolution": 1024,
|
||||
"num_layers": 2,
|
||||
"patch_size": 4,
|
||||
"patch_type": "haar",
|
||||
"scaling_factor": 1.0,
|
||||
"spatial_compression_ratio": 8,
|
||||
"temporal_compression_ratio": 8,
|
||||
"latents_mean": None,
|
||||
"latents_std": None,
|
||||
},
|
||||
},
|
||||
"CV8x8x8-1.0": {
|
||||
"name": "nvidia/Cosmos-1.0-Tokenizer-CV8x8x8",
|
||||
"diffusers_config": {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 16,
|
||||
"encoder_block_out_channels": (128, 256, 512, 512),
|
||||
"decode_block_out_channels": (256, 512, 512, 512),
|
||||
"attention_resolutions": (32,),
|
||||
"resolution": 1024,
|
||||
"num_layers": 2,
|
||||
"patch_size": 4,
|
||||
"patch_type": "haar",
|
||||
"scaling_factor": 1.0,
|
||||
"spatial_compression_ratio": 8,
|
||||
"temporal_compression_ratio": 8,
|
||||
"latents_mean": None,
|
||||
"latents_std": None,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
state_dict = saved_dict
|
||||
if "model" in saved_dict.keys():
|
||||
state_dict = state_dict["model"]
|
||||
if "module" in saved_dict.keys():
|
||||
state_dict = state_dict["module"]
|
||||
if "state_dict" in saved_dict.keys():
|
||||
state_dict = state_dict["state_dict"]
|
||||
return state_dict
|
||||
|
||||
|
||||
def convert_transformer(transformer_type: str, ckpt_path: str):
|
||||
PREFIX_KEY = "net."
|
||||
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
|
||||
|
||||
with init_empty_weights():
|
||||
config = TRANSFORMER_CONFIGS[transformer_type]
|
||||
transformer = CosmosTransformer3DModel(**config)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
if new_key.startswith(PREFIX_KEY):
|
||||
new_key = new_key.removeprefix(PREFIX_KEY)
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_(original_state_dict, key, new_key)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_vae(vae_type: str):
|
||||
model_name = VAE_CONFIGS[vae_type]["name"]
|
||||
snapshot_directory = snapshot_download(model_name, repo_type="model")
|
||||
directory = pathlib.Path(snapshot_directory)
|
||||
|
||||
autoencoder_file = directory / "autoencoder.jit"
|
||||
mean_std_file = directory / "mean_std.pt"
|
||||
|
||||
original_state_dict = torch.jit.load(autoencoder_file.as_posix()).state_dict()
|
||||
if mean_std_file.exists():
|
||||
mean_std = torch.load(mean_std_file, map_location="cpu", weights_only=True)
|
||||
else:
|
||||
mean_std = (None, None)
|
||||
|
||||
config = VAE_CONFIGS[vae_type]["diffusers_config"]
|
||||
config.update(
|
||||
{
|
||||
"latents_mean": mean_std[0].detach().cpu().numpy().tolist(),
|
||||
"latents_std": mean_std[1].detach().cpu().numpy().tolist(),
|
||||
}
|
||||
)
|
||||
vae = AutoencoderKLCosmos(**config)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_(original_state_dict, key, new_key)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
|
||||
parser.add_argument(
|
||||
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
|
||||
)
|
||||
parser.add_argument("--vae_type", type=str, default=None, choices=list(VAE_CONFIGS.keys()), help="Type of VAE")
|
||||
parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
|
||||
parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
|
||||
parser.add_argument("--save_pipeline", action="store_true")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
transformer = None
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
|
||||
if args.save_pipeline:
|
||||
assert args.transformer_ckpt_path is not None
|
||||
assert args.vae_type is not None
|
||||
assert args.text_encoder_path is not None
|
||||
assert args.tokenizer_path is not None
|
||||
|
||||
if args.transformer_ckpt_path is not None:
|
||||
transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path)
|
||||
transformer = transformer.to(dtype=dtype)
|
||||
if not args.save_pipeline:
|
||||
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
if args.vae_type is not None:
|
||||
vae = convert_vae(args.vae_type)
|
||||
if not args.save_pipeline:
|
||||
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
if args.save_pipeline:
|
||||
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=dtype)
|
||||
tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
|
||||
# The original code initializes EDM config with sigma_min=0.0002, but does not make use of it anywhere directly.
|
||||
# So, the sigma_min values that is used is the default value of 0.002.
|
||||
scheduler = EDMEulerScheduler(
|
||||
sigma_min=0.002,
|
||||
sigma_max=80,
|
||||
sigma_data=0.5,
|
||||
sigma_schedule="karras",
|
||||
num_train_timesteps=1000,
|
||||
prediction_type="epsilon",
|
||||
rho=7.0,
|
||||
final_sigmas_type="sigma_min",
|
||||
)
|
||||
|
||||
pipe = CosmosTextToWorldPipeline(
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
@@ -148,6 +148,7 @@ else:
|
||||
"AutoencoderKL",
|
||||
"AutoencoderKLAllegro",
|
||||
"AutoencoderKLCogVideoX",
|
||||
"AutoencoderKLCosmos",
|
||||
"AutoencoderKLHunyuanVideo",
|
||||
"AutoencoderKLLTXVideo",
|
||||
"AutoencoderKLMagvit",
|
||||
@@ -166,6 +167,7 @@ else:
|
||||
"ControlNetModel",
|
||||
"ControlNetUnionModel",
|
||||
"ControlNetXSAdapter",
|
||||
"CosmosTransformer3DModel",
|
||||
"DiTTransformer2DModel",
|
||||
"EasyAnimateTransformer3DModel",
|
||||
"FluxControlNetModel",
|
||||
@@ -357,6 +359,9 @@ else:
|
||||
"CogView3PlusPipeline",
|
||||
"CogView4ControlPipeline",
|
||||
"CogView4Pipeline",
|
||||
"ConsisIDPipeline",
|
||||
"CosmosTextToWorldPipeline",
|
||||
"CosmosVideoToWorldPipeline",
|
||||
"CycleDiffusionPipeline",
|
||||
"EasyAnimateControlPipeline",
|
||||
"EasyAnimateInpaintPipeline",
|
||||
@@ -745,6 +750,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKL,
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLCosmos,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
@@ -763,6 +769,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ControlNetModel,
|
||||
ControlNetUnionModel,
|
||||
ControlNetXSAdapter,
|
||||
CosmosTransformer3DModel,
|
||||
DiTTransformer2DModel,
|
||||
EasyAnimateTransformer3DModel,
|
||||
FluxControlNetModel,
|
||||
@@ -933,6 +940,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogView3PlusPipeline,
|
||||
CogView4ControlPipeline,
|
||||
CogView4Pipeline,
|
||||
ConsisIDPipeline,
|
||||
CosmosTextToWorldPipeline,
|
||||
CosmosVideoToWorldPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
EasyAnimateControlPipeline,
|
||||
EasyAnimateInpaintPipeline,
|
||||
|
||||
@@ -32,6 +32,7 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
|
||||
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
|
||||
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
|
||||
_import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
|
||||
@@ -75,6 +76,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
|
||||
@@ -114,6 +116,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKL,
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLCosmos,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
@@ -151,6 +154,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogView3PlusTransformer2DModel,
|
||||
CogView4Transformer2DModel,
|
||||
ConsisIDTransformer3DModel,
|
||||
CosmosTransformer3DModel,
|
||||
DiTTransformer2DModel,
|
||||
DualTransformer2DModel,
|
||||
EasyAnimateTransformer3DModel,
|
||||
|
||||
@@ -203,8 +203,8 @@ class Attention(nn.Module):
|
||||
self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
|
||||
self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
|
||||
elif qk_norm == "rms_norm":
|
||||
self.norm_q = RMSNorm(dim_head, eps=eps)
|
||||
self.norm_k = RMSNorm(dim_head, eps=eps)
|
||||
self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
elif qk_norm == "rms_norm_across_heads":
|
||||
# LTX applies qk norm across all heads
|
||||
self.norm_q = RMSNorm(dim_head * heads, eps=eps)
|
||||
|
||||
@@ -3,6 +3,7 @@ from .autoencoder_dc import AutoencoderDC
|
||||
from .autoencoder_kl import AutoencoderKL
|
||||
from .autoencoder_kl_allegro import AutoencoderKLAllegro
|
||||
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
|
||||
from .autoencoder_kl_cosmos import AutoencoderKLCosmos
|
||||
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
|
||||
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
|
||||
from .autoencoder_kl_magvit import AutoencoderKLMagvit
|
||||
|
||||
1106
src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
Normal file
1106
src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -744,6 +744,17 @@ class DiagonalGaussianDistribution(object):
|
||||
return self.mean
|
||||
|
||||
|
||||
class IdentityDistribution(object):
|
||||
def __init__(self, parameters: torch.Tensor):
|
||||
self.parameters = parameters
|
||||
|
||||
def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
|
||||
return self.parameters
|
||||
|
||||
def mode(self) -> torch.Tensor:
|
||||
return self.parameters
|
||||
|
||||
|
||||
class EncoderTiny(nn.Module):
|
||||
r"""
|
||||
The `EncoderTiny` layer is a simpler version of the `Encoder` layer.
|
||||
|
||||
@@ -1204,7 +1204,7 @@ def apply_rotary_emb(
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
elif use_real_unbind_dim == -2:
|
||||
# Used for Stable Audio, OmniGen and CogView4
|
||||
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
|
||||
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
||||
else:
|
||||
|
||||
@@ -19,6 +19,7 @@ if is_torch_available():
|
||||
from .transformer_allegro import AllegroTransformer3DModel
|
||||
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
|
||||
from .transformer_cogview4 import CogView4Transformer2DModel
|
||||
from .transformer_cosmos import CosmosTransformer3DModel
|
||||
from .transformer_easyanimate import EasyAnimateTransformer3DModel
|
||||
from .transformer_flux import FluxTransformer2DModel
|
||||
from .transformer_hidream_image import HiDreamImageTransformer2DModel
|
||||
|
||||
551
src/diffusers/models/transformers/transformer_cosmos.py
Normal file
551
src/diffusers/models/transformers/transformer_cosmos.py
Normal file
@@ -0,0 +1,551 @@
|
||||
# Copyright 2024 The NVIDIA Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision import transforms
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import RMSNorm
|
||||
|
||||
|
||||
class CosmosPatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels: int, out_channels: int, patch_size: Tuple[int, int, int], bias: bool = True
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.proj = nn.Linear(in_channels * patch_size[0] * patch_size[1] * patch_size[2], out_channels, bias=bias)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, num_channels, num_frames // p_t, p_t, height // p_h, p_h, width // p_w, p_w
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7)
|
||||
hidden_states = self.proj(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CosmosTimestepEmbedding(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int) -> None:
|
||||
super().__init__()
|
||||
self.linear_1 = nn.Linear(in_features, out_features, bias=False)
|
||||
self.activation = nn.SiLU()
|
||||
self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False)
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
emb = self.linear_1(timesteps)
|
||||
emb = self.activation(emb)
|
||||
emb = self.linear_2(emb)
|
||||
return emb
|
||||
|
||||
|
||||
class CosmosEmbedding(nn.Module):
|
||||
def __init__(self, embedding_dim: int, condition_dim: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(embedding_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0)
|
||||
self.t_embedder = CosmosTimestepEmbedding(embedding_dim, condition_dim)
|
||||
self.norm = RMSNorm(embedding_dim, eps=1e-6, elementwise_affine=True)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, timestep: torch.LongTensor) -> torch.Tensor:
|
||||
timesteps_proj = self.time_proj(timestep).type_as(hidden_states)
|
||||
temb = self.t_embedder(timesteps_proj)
|
||||
embedded_timestep = self.norm(timesteps_proj)
|
||||
return temb, embedded_timestep
|
||||
|
||||
|
||||
class CosmosAdaLayerNorm(nn.Module):
|
||||
def __init__(self, in_features: int, hidden_features: int) -> None:
|
||||
super().__init__()
|
||||
self.embedding_dim = in_features
|
||||
|
||||
self.activation = nn.SiLU()
|
||||
self.norm = nn.LayerNorm(in_features, elementwise_affine=False, eps=1e-6)
|
||||
self.linear_1 = nn.Linear(in_features, hidden_features, bias=False)
|
||||
self.linear_2 = nn.Linear(hidden_features, 2 * in_features, bias=False)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, embedded_timestep: torch.Tensor, temb: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
embedded_timestep = self.activation(embedded_timestep)
|
||||
embedded_timestep = self.linear_1(embedded_timestep)
|
||||
embedded_timestep = self.linear_2(embedded_timestep)
|
||||
|
||||
if temb is not None:
|
||||
embedded_timestep = embedded_timestep + temb[:, : 2 * self.embedding_dim]
|
||||
|
||||
shift, scale = embedded_timestep.chunk(2, dim=1)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CosmosAdaLayerNormZero(nn.Module):
|
||||
def __init__(self, in_features: int, hidden_features: Optional[int] = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(in_features, elementwise_affine=False, eps=1e-6)
|
||||
self.activation = nn.SiLU()
|
||||
|
||||
if hidden_features is None:
|
||||
self.linear_1 = nn.Identity()
|
||||
else:
|
||||
self.linear_1 = nn.Linear(in_features, hidden_features, bias=False)
|
||||
|
||||
self.linear_2 = nn.Linear(hidden_features, 3 * in_features, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
embedded_timestep: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
embedded_timestep = self.activation(embedded_timestep)
|
||||
embedded_timestep = self.linear_1(embedded_timestep)
|
||||
embedded_timestep = self.linear_2(embedded_timestep)
|
||||
|
||||
if temb is not None:
|
||||
embedded_timestep = embedded_timestep + temb
|
||||
|
||||
shift, scale, gate = embedded_timestep.chunk(3, dim=1)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
return hidden_states, gate
|
||||
|
||||
|
||||
class CosmosAttnProcessor2_0:
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("CosmosAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# 1. QKV projections
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
|
||||
# 2. QK normalization
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# 3. Apply RoPE
|
||||
if image_rotary_emb is not None:
|
||||
from ..embeddings import apply_rotary_emb
|
||||
|
||||
query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
|
||||
|
||||
# 4. Prepare for GQA
|
||||
query_idx = torch.tensor(query.size(3), device=query.device)
|
||||
key_idx = torch.tensor(key.size(3), device=key.device)
|
||||
value_idx = torch.tensor(value.size(3), device=value.device)
|
||||
key = key.repeat_interleave(query_idx // key_idx, dim=3)
|
||||
value = value.repeat_interleave(query_idx // value_idx, dim=3)
|
||||
|
||||
# 5. Attention
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query)
|
||||
|
||||
# 6. Output projection
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CosmosTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
cross_attention_dim: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
adaln_lora_dim: int = 256,
|
||||
qk_norm: str = "rms_norm",
|
||||
out_bias: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm1 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
|
||||
self.attn1 = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
qk_norm=qk_norm,
|
||||
elementwise_affine=True,
|
||||
out_bias=out_bias,
|
||||
processor=CosmosAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
self.norm2 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
|
||||
self.attn2 = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
qk_norm=qk_norm,
|
||||
elementwise_affine=True,
|
||||
out_bias=out_bias,
|
||||
processor=CosmosAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
self.norm3 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
|
||||
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu", bias=out_bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
embedded_timestep: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
extra_pos_emb: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if extra_pos_emb is not None:
|
||||
hidden_states = hidden_states + extra_pos_emb
|
||||
|
||||
# 1. Self Attention
|
||||
norm_hidden_states, gate = self.norm1(hidden_states, embedded_timestep, temb)
|
||||
attn_output = self.attn1(norm_hidden_states, image_rotary_emb=image_rotary_emb)
|
||||
hidden_states = hidden_states + gate.unsqueeze(1) * attn_output
|
||||
|
||||
# 2. Cross Attention
|
||||
norm_hidden_states, gate = self.norm2(hidden_states, embedded_timestep, temb)
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
||||
)
|
||||
hidden_states = hidden_states + gate.unsqueeze(1) * attn_output
|
||||
|
||||
# 3. Feed Forward
|
||||
norm_hidden_states, gate = self.norm3(hidden_states, embedded_timestep, temb)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + gate.unsqueeze(1) * ff_output
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CosmosRotaryPosEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
max_size: Tuple[int, int, int] = (128, 240, 240),
|
||||
patch_size: Tuple[int, int, int] = (1, 2, 2),
|
||||
base_fps: int = 24,
|
||||
rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.max_size = [size // patch for size, patch in zip(max_size, patch_size)]
|
||||
self.patch_size = patch_size
|
||||
self.base_fps = base_fps
|
||||
|
||||
self.dim_h = hidden_size // 6 * 2
|
||||
self.dim_w = hidden_size // 6 * 2
|
||||
self.dim_t = hidden_size - self.dim_h - self.dim_w
|
||||
|
||||
self.h_ntk_factor = rope_scale[1] ** (self.dim_h / (self.dim_h - 2))
|
||||
self.w_ntk_factor = rope_scale[2] ** (self.dim_w / (self.dim_w - 2))
|
||||
self.t_ntk_factor = rope_scale[0] ** (self.dim_t / (self.dim_t - 2))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, fps: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
pe_size = [num_frames // self.patch_size[0], height // self.patch_size[1], width // self.patch_size[2]]
|
||||
device = hidden_states.device
|
||||
|
||||
h_theta = 10000.0 * self.h_ntk_factor
|
||||
w_theta = 10000.0 * self.w_ntk_factor
|
||||
t_theta = 10000.0 * self.t_ntk_factor
|
||||
|
||||
seq = torch.arange(max(self.max_size), device=device, dtype=torch.float32)
|
||||
dim_h_range = (
|
||||
torch.arange(0, self.dim_h, 2, device=device, dtype=torch.float32)[: (self.dim_h // 2)] / self.dim_h
|
||||
)
|
||||
dim_w_range = (
|
||||
torch.arange(0, self.dim_w, 2, device=device, dtype=torch.float32)[: (self.dim_w // 2)] / self.dim_w
|
||||
)
|
||||
dim_t_range = (
|
||||
torch.arange(0, self.dim_t, 2, device=device, dtype=torch.float32)[: (self.dim_t // 2)] / self.dim_t
|
||||
)
|
||||
h_spatial_freqs = 1.0 / (h_theta**dim_h_range)
|
||||
w_spatial_freqs = 1.0 / (w_theta**dim_w_range)
|
||||
temporal_freqs = 1.0 / (t_theta**dim_t_range)
|
||||
|
||||
emb_h = torch.outer(seq[: pe_size[1]], h_spatial_freqs)[None, :, None, :].repeat(pe_size[0], 1, pe_size[2], 1)
|
||||
emb_w = torch.outer(seq[: pe_size[2]], w_spatial_freqs)[None, None, :, :].repeat(pe_size[0], pe_size[1], 1, 1)
|
||||
|
||||
# Apply sequence scaling in temporal dimension
|
||||
if fps is None:
|
||||
# Images
|
||||
emb_t = torch.outer(seq[: pe_size[0]], temporal_freqs)
|
||||
else:
|
||||
# Videos
|
||||
emb_t = torch.outer(seq[: pe_size[0]] / fps * self.base_fps, temporal_freqs)
|
||||
|
||||
emb_t = emb_t[:, None, None, :].repeat(1, pe_size[1], pe_size[2], 1)
|
||||
freqs = torch.cat([emb_t, emb_h, emb_w] * 2, dim=-1).flatten(0, 2).float()
|
||||
cos = torch.cos(freqs)
|
||||
sin = torch.sin(freqs)
|
||||
return cos, sin
|
||||
|
||||
|
||||
class CosmosLearnablePositionalEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
max_size: Tuple[int, int, int],
|
||||
patch_size: Tuple[int, int, int],
|
||||
eps: float = 1e-6,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.max_size = [size // patch for size, patch in zip(max_size, patch_size)]
|
||||
self.patch_size = patch_size
|
||||
self.eps = eps
|
||||
|
||||
self.pos_emb_t = nn.Parameter(torch.zeros(self.max_size[0], hidden_size))
|
||||
self.pos_emb_h = nn.Parameter(torch.zeros(self.max_size[1], hidden_size))
|
||||
self.pos_emb_w = nn.Parameter(torch.zeros(self.max_size[2], hidden_size))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
pe_size = [num_frames // self.patch_size[0], height // self.patch_size[1], width // self.patch_size[2]]
|
||||
|
||||
emb_t = self.pos_emb_t[: pe_size[0]][None, :, None, None, :].repeat(batch_size, 1, pe_size[1], pe_size[2], 1)
|
||||
emb_h = self.pos_emb_h[: pe_size[1]][None, None, :, None, :].repeat(batch_size, pe_size[0], 1, pe_size[2], 1)
|
||||
emb_w = self.pos_emb_w[: pe_size[2]][None, None, None, :, :].repeat(batch_size, pe_size[0], pe_size[1], 1, 1)
|
||||
emb = emb_t + emb_h + emb_w
|
||||
emb = emb.flatten(1, 3)
|
||||
|
||||
norm = torch.linalg.vector_norm(emb, dim=-1, keepdim=True, dtype=torch.float32)
|
||||
norm = torch.add(self.eps, norm, alpha=np.sqrt(norm.numel() / emb.numel()))
|
||||
return (emb / norm).type_as(hidden_states)
|
||||
|
||||
|
||||
class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in [Cosmos](https://github.com/NVIDIA/Cosmos).
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
num_attention_heads (`int`, defaults to `32`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `128`):
|
||||
The number of channels in each attention head.
|
||||
num_layers (`int`, defaults to `28`):
|
||||
The number of layers of transformer blocks to use.
|
||||
mlp_ratio (`float`, defaults to `4.0`):
|
||||
The ratio of the hidden layer size to the input size in the feedforward network.
|
||||
text_embed_dim (`int`, defaults to `4096`):
|
||||
Input dimension of text embeddings from the text encoder.
|
||||
adaln_lora_dim (`int`, defaults to `256`):
|
||||
The hidden dimension of the Adaptive LayerNorm LoRA layer.
|
||||
max_size (`Tuple[int, int, int]`, defaults to `(128, 240, 240)`):
|
||||
The maximum size of the input latent tensors in the temporal, height, and width dimensions.
|
||||
patch_size (`Tuple[int, int, int]`, defaults to `(1, 2, 2)`):
|
||||
The patch size to use for patchifying the input latent tensors in the temporal, height, and width
|
||||
dimensions.
|
||||
rope_scale (`Tuple[float, float, float]`, defaults to `(2.0, 1.0, 1.0)`):
|
||||
The scaling factor to use for RoPE in the temporal, height, and width dimensions.
|
||||
concat_padding_mask (`bool`, defaults to `True`):
|
||||
Whether to concatenate the padding mask to the input latent tensors.
|
||||
extra_pos_embed_type (`str`, *optional*, defaults to `learnable`):
|
||||
The type of extra positional embeddings to use. Can be one of `None` or `learnable`.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["patch_embed", "final_layer", "norm"]
|
||||
_no_split_modules = ["CosmosTransformerBlock"]
|
||||
_keep_in_fp32_modules = ["learnable_pos_embed"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 16,
|
||||
num_attention_heads: int = 32,
|
||||
attention_head_dim: int = 128,
|
||||
num_layers: int = 28,
|
||||
mlp_ratio: float = 4.0,
|
||||
text_embed_dim: int = 1024,
|
||||
adaln_lora_dim: int = 256,
|
||||
max_size: Tuple[int, int, int] = (128, 240, 240),
|
||||
patch_size: Tuple[int, int, int] = (1, 2, 2),
|
||||
rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
|
||||
concat_padding_mask: bool = True,
|
||||
extra_pos_embed_type: Optional[str] = "learnable",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
# 1. Patch Embedding
|
||||
patch_embed_in_channels = in_channels + 1 if concat_padding_mask else in_channels
|
||||
self.patch_embed = CosmosPatchEmbed(patch_embed_in_channels, hidden_size, patch_size, bias=False)
|
||||
|
||||
# 2. Positional Embedding
|
||||
self.rope = CosmosRotaryPosEmbed(
|
||||
hidden_size=attention_head_dim, max_size=max_size, patch_size=patch_size, rope_scale=rope_scale
|
||||
)
|
||||
|
||||
self.learnable_pos_embed = None
|
||||
if extra_pos_embed_type == "learnable":
|
||||
self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
|
||||
hidden_size=hidden_size,
|
||||
max_size=max_size,
|
||||
patch_size=patch_size,
|
||||
)
|
||||
|
||||
# 3. Time Embedding
|
||||
self.time_embed = CosmosEmbedding(hidden_size, hidden_size)
|
||||
|
||||
# 4. Transformer Blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
CosmosTransformerBlock(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
cross_attention_dim=text_embed_dim,
|
||||
mlp_ratio=mlp_ratio,
|
||||
adaln_lora_dim=adaln_lora_dim,
|
||||
qk_norm="rms_norm",
|
||||
out_bias=False,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 5. Output norm & projection
|
||||
self.norm_out = CosmosAdaLayerNorm(hidden_size, adaln_lora_dim)
|
||||
self.proj_out = nn.Linear(
|
||||
hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
fps: Optional[int] = None,
|
||||
condition_mask: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
|
||||
# 1. Concatenate padding mask if needed & prepare attention mask
|
||||
if condition_mask is not None:
|
||||
hidden_states = torch.cat([hidden_states, condition_mask], dim=1)
|
||||
|
||||
if self.config.concat_padding_mask:
|
||||
padding_mask = transforms.functional.resize(
|
||||
padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
|
||||
)
|
||||
hidden_states = torch.cat(
|
||||
[hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, S]
|
||||
|
||||
# 2. Generate positional embeddings
|
||||
image_rotary_emb = self.rope(hidden_states, fps=fps)
|
||||
extra_pos_emb = self.learnable_pos_embed(hidden_states) if self.config.extra_pos_embed_type else None
|
||||
|
||||
# 3. Patchify input
|
||||
p_t, p_h, p_w = self.config.patch_size
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
post_patch_height = height // p_h
|
||||
post_patch_width = width // p_w
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
hidden_states = hidden_states.flatten(1, 3) # [B, T, H, W, C] -> [B, THW, C]
|
||||
|
||||
# 4. Timestep embeddings
|
||||
temb, embedded_timestep = self.time_embed(hidden_states, timestep)
|
||||
|
||||
# 5. Transformer blocks
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
embedded_timestep,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
extra_pos_emb,
|
||||
attention_mask,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
embedded_timestep=embedded_timestep,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
extra_pos_emb=extra_pos_emb,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
# 6. Output norm & projection & unpatchify
|
||||
hidden_states = self.norm_out(hidden_states, embedded_timestep, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1))
|
||||
hidden_states = hidden_states.unflatten(1, (post_patch_num_frames, post_patch_height, post_patch_width))
|
||||
# Please just kill me at this point. What even is this permutation order and why is it different from the patching order?
|
||||
# Another few hours of sanity lost to the void.
|
||||
hidden_states = hidden_states.permute(0, 7, 1, 6, 2, 4, 3, 5)
|
||||
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states,)
|
||||
|
||||
return Transformer2DModelOutput(sample=hidden_states)
|
||||
@@ -156,6 +156,8 @@ else:
|
||||
]
|
||||
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
|
||||
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
|
||||
_import_structure["consisid"] = ["ConsisIDPipeline"]
|
||||
_import_structure["cosmos"] = ["CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline"]
|
||||
_import_structure["controlnet"].extend(
|
||||
[
|
||||
"BlipDiffusionControlNetPipeline",
|
||||
@@ -546,6 +548,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionControlNetXSPipeline,
|
||||
StableDiffusionXLControlNetXSPipeline,
|
||||
)
|
||||
from .cosmos import CosmosTextToWorldPipeline, CosmosVideoToWorldPipeline
|
||||
from .deepfloyd_if import (
|
||||
IFImg2ImgPipeline,
|
||||
IFImg2ImgSuperResolutionPipeline,
|
||||
|
||||
50
src/diffusers/pipelines/cosmos/__init__.py
Normal file
50
src/diffusers/pipelines/cosmos/__init__.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
|
||||
_import_structure["pipeline_cosmos_video2world"] = ["CosmosVideoToWorldPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline
|
||||
from .pipeline_cosmos_video2world import CosmosVideoToWorldPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
667
src/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py
Normal file
667
src/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py
Normal file
@@ -0,0 +1,667 @@
|
||||
# Copyright 2024 The NVIDIA Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...models import AutoencoderKLCosmos, CosmosTransformer3DModel
|
||||
from ...schedulers import EDMEulerScheduler
|
||||
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import CosmosPipelineOutput
|
||||
|
||||
|
||||
if is_cosmos_guardrail_available():
|
||||
from cosmos_guardrail import CosmosSafetyChecker
|
||||
else:
|
||||
|
||||
class CosmosSafetyChecker:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise ImportError(
|
||||
"`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
|
||||
)
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import CosmosTextToWorldPipeline
|
||||
>>> from diffusers.utils import export_to_video
|
||||
|
||||
>>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"
|
||||
>>> pipe = CosmosTextToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of field that keeps the focus on the robot while subtly blurring the background for a cinematic effect."
|
||||
|
||||
>>> output = pipe(prompt=prompt).frames[0]
|
||||
>>> export_to_video(output, "output.mp4", fps=30)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class CosmosTextToWorldPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using [Cosmos](https://github.com/NVIDIA/Cosmos).
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. Cosmos uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
||||
[t5-11b](https://huggingface.co/google-t5/t5-11b) variant.
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
transformer ([`CosmosTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLCosmos`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
|
||||
_optional_components = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast,
|
||||
transformer: CosmosTransformer3DModel,
|
||||
vae: AutoencoderKLCosmos,
|
||||
scheduler: EDMEulerScheduler,
|
||||
safety_checker: CosmosSafetyChecker = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if safety_checker is None:
|
||||
safety_checker = CosmosSafetyChecker()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = (
|
||||
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 8
|
||||
)
|
||||
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
return_length=True,
|
||||
return_offsets_mapping=False,
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask.bool().to(device)
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device), attention_mask=prompt_attention_mask
|
||||
).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
lengths = prompt_attention_mask.sum(dim=1).cpu()
|
||||
for i, length in enumerate(lengths):
|
||||
prompt_embeds[i, length:] = 0
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = negative_prompt_embeds.shape
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels_latents: 16,
|
||||
height: int = 704,
|
||||
width: int = 1280,
|
||||
num_frames: int = 121,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype) * self.scheduler.config.sigma_max
|
||||
|
||||
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
latent_height = height // self.vae_scale_factor_spatial
|
||||
latent_width = width // self.vae_scale_factor_spatial
|
||||
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
return latents * self.scheduler.config.sigma_max
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 704,
|
||||
width: int = 1280,
|
||||
num_frames: int = 121,
|
||||
num_inference_steps: int = 36,
|
||||
guidance_scale: float = 7.0,
|
||||
fps: int = 30,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, defaults to `720`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `1280`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, defaults to `129`):
|
||||
The number of frames in the generated video.
|
||||
num_inference_steps (`int`, defaults to `50`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `6.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`.
|
||||
fps (`int`, defaults to `30`):
|
||||
The frames per second of the generated video.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~CosmosPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
||||
the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if self.safety_checker is None:
|
||||
raise ValueError(
|
||||
f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
|
||||
"[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
|
||||
f"Please ensure that you are compliant with the license agreement."
|
||||
)
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
if self.safety_checker is not None:
|
||||
self.safety_checker.to(device)
|
||||
if prompt is not None:
|
||||
prompt_list = [prompt] if isinstance(prompt, str) else prompt
|
||||
for p in prompt_list:
|
||||
if not self.safety_checker.check_text_safety(p):
|
||||
raise ValueError(
|
||||
f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
|
||||
f"prompt abides by the NVIDIA Open Model License Agreement."
|
||||
)
|
||||
self.safety_checker.to("cpu")
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
transformer_dtype = self.transformer.dtype
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
timestep = t.expand(latents.shape[0]).to(transformer_dtype)
|
||||
|
||||
latent_model_input = latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
latent_model_input = latent_model_input.to(transformer_dtype)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
fps=fps,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
sample = latents
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
fps=fps,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = torch.cat([noise_pred_uncond, noise_pred])
|
||||
sample = torch.cat([sample, sample])
|
||||
|
||||
# pred_original_sample (x0)
|
||||
noise_pred = self.scheduler.step(noise_pred, t, sample, return_dict=False)[1]
|
||||
self.scheduler._step_index -= 1
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_cond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
||||
|
||||
# pred_sample (eps)
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, return_dict=False, pred_original_sample=noise_pred
|
||||
)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
if self.vae.config.latents_mean is not None:
|
||||
latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std
|
||||
latents_mean = (
|
||||
torch.tensor(latents_mean)
|
||||
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
|
||||
.to(latents)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(latents_std)
|
||||
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
|
||||
.to(latents)
|
||||
)
|
||||
latents = latents * latents_std / self.scheduler.config.sigma_data + latents_mean
|
||||
else:
|
||||
latents = latents / self.scheduler.config.sigma_data
|
||||
video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
|
||||
|
||||
if self.safety_checker is not None:
|
||||
self.safety_checker.to(device)
|
||||
video = self.video_processor.postprocess_video(video, output_type="np")
|
||||
video = (video * 255).astype(np.uint8)
|
||||
video_batch = []
|
||||
for vid in video:
|
||||
vid = self.safety_checker.check_video_safety(vid)
|
||||
video_batch.append(vid)
|
||||
video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
|
||||
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
self.safety_checker.to("cpu")
|
||||
else:
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return CosmosPipelineOutput(frames=video)
|
||||
828
src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py
Normal file
828
src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py
Normal file
@@ -0,0 +1,828 @@
|
||||
# Copyright 2024 The NVIDIA Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...models import AutoencoderKLCosmos, CosmosTransformer3DModel
|
||||
from ...schedulers import EDMEulerScheduler
|
||||
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import CosmosPipelineOutput
|
||||
|
||||
|
||||
if is_cosmos_guardrail_available():
|
||||
from cosmos_guardrail import CosmosSafetyChecker
|
||||
else:
|
||||
|
||||
class CosmosSafetyChecker:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise ImportError(
|
||||
"`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
|
||||
)
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
Image conditioning:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import CosmosVideoToWorldPipeline
|
||||
>>> from diffusers.utils import export_to_video, load_image
|
||||
|
||||
>>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"
|
||||
>>> pipe = CosmosVideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered clouds, suggesting a bright, sunny day."
|
||||
>>> image = load_image(
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg"
|
||||
... )
|
||||
|
||||
>>> video = pipe(image=image, prompt=prompt).frames[0]
|
||||
>>> export_to_video(video, "output.mp4", fps=30)
|
||||
```
|
||||
|
||||
Video conditioning:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import CosmosVideoToWorldPipeline
|
||||
>>> from diffusers.utils import export_to_video, load_video
|
||||
|
||||
>>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"
|
||||
>>> pipe = CosmosVideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
|
||||
>>> pipe.transformer = torch.compile(pipe.transformer)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
|
||||
>>> video = load_video(
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
|
||||
... )[
|
||||
... :21
|
||||
... ] # This example uses only the first 21 frames
|
||||
|
||||
>>> video = pipe(video=video, prompt=prompt).frames[0]
|
||||
>>> export_to_video(video, "output.mp4", fps=30)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class CosmosVideoToWorldPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for image-to-video and video-to-video generation using [Cosmos](https://github.com/NVIDIA/Cosmos).
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. Cosmos uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
||||
[t5-11b](https://huggingface.co/google-t5/t5-11b) variant.
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
transformer ([`CosmosTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLCosmos`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
|
||||
_optional_components = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast,
|
||||
transformer: CosmosTransformer3DModel,
|
||||
vae: AutoencoderKLCosmos,
|
||||
scheduler: EDMEulerScheduler,
|
||||
safety_checker: CosmosSafetyChecker = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if safety_checker is None:
|
||||
safety_checker = CosmosSafetyChecker()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = (
|
||||
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 8
|
||||
)
|
||||
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
return_length=True,
|
||||
return_offsets_mapping=False,
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask.bool().to(device)
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device), attention_mask=prompt_attention_mask
|
||||
).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
lengths = prompt_attention_mask.sum(dim=1).cpu()
|
||||
for i, length in enumerate(lengths):
|
||||
prompt_embeds[i, length:] = 0
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = negative_prompt_embeds.shape
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
video: torch.Tensor,
|
||||
batch_size: int,
|
||||
num_channels_latents: 16,
|
||||
height: int = 704,
|
||||
width: int = 1280,
|
||||
num_frames: int = 121,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
input_frames_guidance: bool = False,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
num_cond_frames = video.size(2)
|
||||
if num_cond_frames >= num_frames:
|
||||
# Take the last `num_frames` frames for conditioning
|
||||
num_cond_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
video = video[:, :, -num_frames:]
|
||||
else:
|
||||
num_cond_latent_frames = (num_cond_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
num_padding_frames = num_frames - num_cond_frames
|
||||
padding = video.new_zeros(video.size(0), video.size(1), num_padding_frames, video.size(3), video.size(4))
|
||||
video = torch.cat([video, padding], dim=2)
|
||||
|
||||
if isinstance(generator, list):
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
|
||||
|
||||
init_latents = torch.cat(init_latents, dim=0).to(dtype)
|
||||
|
||||
if self.vae.config.latents_mean is not None:
|
||||
latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std
|
||||
latents_mean = (
|
||||
torch.tensor(latents_mean)
|
||||
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : init_latents.size(2)]
|
||||
.to(init_latents)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(latents_std)
|
||||
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : init_latents.size(2)]
|
||||
.to(init_latents)
|
||||
)
|
||||
init_latents = (init_latents - latents_mean) * self.scheduler.config.sigma_data / latents_std
|
||||
else:
|
||||
init_latents = init_latents * self.scheduler.config.sigma_data
|
||||
|
||||
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
latent_height = height // self.vae_scale_factor_spatial
|
||||
latent_width = width // self.vae_scale_factor_spatial
|
||||
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
latents = latents * self.scheduler.config.sigma_max
|
||||
|
||||
padding_shape = (batch_size, 1, num_latent_frames, latent_height, latent_width)
|
||||
ones_padding = latents.new_ones(padding_shape)
|
||||
zeros_padding = latents.new_zeros(padding_shape)
|
||||
|
||||
cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
|
||||
cond_indicator[:, :, :num_cond_latent_frames] = 1.0
|
||||
cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
|
||||
|
||||
uncond_indicator = uncond_mask = None
|
||||
if do_classifier_free_guidance:
|
||||
uncond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
|
||||
uncond_indicator[:, :, :num_cond_latent_frames] = 1.0
|
||||
uncond_mask = zeros_padding
|
||||
if not input_frames_guidance:
|
||||
uncond_mask = uncond_indicator * ones_padding + (1 - uncond_indicator) * zeros_padding
|
||||
|
||||
return latents, init_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
image=None,
|
||||
video=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if image is None and video is None:
|
||||
raise ValueError("Either `image` or `video` has to be provided.")
|
||||
if image is not None and video is not None:
|
||||
raise ValueError("Only one of `image` or `video` has to be provided.")
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: PipelineImageInput = None,
|
||||
video: List[PipelineImageInput] = None,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 704,
|
||||
width: int = 1280,
|
||||
num_frames: int = 121,
|
||||
num_inference_steps: int = 36,
|
||||
guidance_scale: float = 7.0,
|
||||
input_frames_guidance: bool = False,
|
||||
augment_sigma: float = 0.001,
|
||||
fps: int = 30,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, defaults to `720`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `1280`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, defaults to `129`):
|
||||
The number of frames in the generated video.
|
||||
num_inference_steps (`int`, defaults to `50`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `6.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`.
|
||||
fps (`int`, defaults to `30`):
|
||||
The frames per second of the generated video.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~CosmosPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
||||
the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if self.safety_checker is None:
|
||||
raise ValueError(
|
||||
f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
|
||||
"[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
|
||||
f"Please ensure that you are compliant with the license agreement."
|
||||
)
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs, image, video)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
if self.safety_checker is not None:
|
||||
self.safety_checker.to(device)
|
||||
if prompt is not None:
|
||||
prompt_list = [prompt] if isinstance(prompt, str) else prompt
|
||||
for p in prompt_list:
|
||||
if not self.safety_checker.check_text_safety(p):
|
||||
raise ValueError(
|
||||
f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
|
||||
f"prompt abides by the NVIDIA Open Model License Agreement."
|
||||
)
|
||||
self.safety_checker.to("cpu")
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
vae_dtype = self.vae.dtype
|
||||
transformer_dtype = self.transformer.dtype
|
||||
|
||||
if image is not None:
|
||||
video = self.video_processor.preprocess(image, height, width).unsqueeze(2)
|
||||
else:
|
||||
video = self.video_processor.preprocess_video(video, height, width)
|
||||
video = video.to(device=device, dtype=vae_dtype)
|
||||
|
||||
num_channels_latents = self.transformer.config.in_channels - 1
|
||||
latents, conditioning_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask = self.prepare_latents(
|
||||
video,
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
self.do_classifier_free_guidance,
|
||||
input_frames_guidance,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
cond_mask = cond_mask.to(transformer_dtype)
|
||||
if self.do_classifier_free_guidance:
|
||||
uncond_mask = uncond_mask.to(transformer_dtype)
|
||||
|
||||
augment_sigma = torch.tensor([augment_sigma], device=device, dtype=torch.float32)
|
||||
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
timestep = t.expand(latents.shape[0]).to(transformer_dtype)
|
||||
|
||||
current_sigma = self.scheduler.sigmas[i]
|
||||
is_augment_sigma_greater = augment_sigma >= current_sigma
|
||||
|
||||
c_in_augment = self.scheduler._get_conditioning_c_in(augment_sigma)
|
||||
c_in_original = self.scheduler._get_conditioning_c_in(current_sigma)
|
||||
|
||||
current_cond_indicator = cond_indicator * 0 if is_augment_sigma_greater else cond_indicator
|
||||
cond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32)
|
||||
cond_latent = conditioning_latents + cond_noise * augment_sigma[:, None, None, None, None]
|
||||
cond_latent = cond_latent * c_in_augment / c_in_original
|
||||
cond_latent = current_cond_indicator * cond_latent + (1 - current_cond_indicator) * latents
|
||||
cond_latent = self.scheduler.scale_model_input(cond_latent, t)
|
||||
cond_latent = cond_latent.to(transformer_dtype)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=cond_latent,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
fps=fps,
|
||||
condition_mask=cond_mask,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
sample = latents
|
||||
if self.do_classifier_free_guidance:
|
||||
current_uncond_indicator = uncond_indicator * 0 if is_augment_sigma_greater else uncond_indicator
|
||||
uncond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32)
|
||||
uncond_latent = conditioning_latents + uncond_noise * augment_sigma[:, None, None, None, None]
|
||||
uncond_latent = uncond_latent * c_in_augment / c_in_original
|
||||
uncond_latent = current_uncond_indicator * uncond_latent + (1 - current_uncond_indicator) * latents
|
||||
uncond_latent = self.scheduler.scale_model_input(uncond_latent, t)
|
||||
uncond_latent = uncond_latent.to(transformer_dtype)
|
||||
|
||||
noise_pred_uncond = self.transformer(
|
||||
hidden_states=uncond_latent,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
fps=fps,
|
||||
condition_mask=uncond_mask,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = torch.cat([noise_pred_uncond, noise_pred])
|
||||
sample = torch.cat([sample, sample])
|
||||
|
||||
# pred_original_sample (x0)
|
||||
noise_pred = self.scheduler.step(noise_pred, t, sample, return_dict=False)[1]
|
||||
self.scheduler._step_index -= 1
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
|
||||
noise_pred_uncond = (
|
||||
current_uncond_indicator * conditioning_latents
|
||||
+ (1 - current_uncond_indicator) * noise_pred_uncond
|
||||
)
|
||||
noise_pred_cond = (
|
||||
current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred_cond
|
||||
)
|
||||
noise_pred = noise_pred_cond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
||||
else:
|
||||
noise_pred = (
|
||||
current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred
|
||||
)
|
||||
|
||||
# pred_sample (eps)
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, return_dict=False, pred_original_sample=noise_pred
|
||||
)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
if self.vae.config.latents_mean is not None:
|
||||
latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std
|
||||
latents_mean = (
|
||||
torch.tensor(latents_mean)
|
||||
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
|
||||
.to(latents)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(latents_std)
|
||||
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
|
||||
.to(latents)
|
||||
)
|
||||
latents = latents * latents_std / self.scheduler.config.sigma_data + latents_mean
|
||||
else:
|
||||
latents = latents / self.scheduler.config.sigma_data
|
||||
video = self.vae.decode(latents.to(vae_dtype), return_dict=False)[0]
|
||||
|
||||
if self.safety_checker is not None:
|
||||
self.safety_checker.to(device)
|
||||
video = self.video_processor.postprocess_video(video, output_type="np")
|
||||
video = (video * 255).astype(np.uint8)
|
||||
video_batch = []
|
||||
for vid in video:
|
||||
vid = self.safety_checker.check_video_safety(vid)
|
||||
video_batch.append(vid)
|
||||
video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
|
||||
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
self.safety_checker.to("cpu")
|
||||
else:
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return CosmosPipelineOutput(frames=video)
|
||||
20
src/diffusers/pipelines/cosmos/pipeline_output.py
Normal file
20
src/diffusers/pipelines/cosmos/pipeline_output.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class CosmosPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for Cosmos pipelines.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
||||
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`.
|
||||
"""
|
||||
|
||||
frames: torch.Tensor
|
||||
@@ -144,7 +144,7 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
|
||||
def precondition_inputs(self, sample, sigma):
|
||||
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
|
||||
c_in = self._get_conditioning_c_in(sigma)
|
||||
scaled_sample = sample * c_in
|
||||
return scaled_sample
|
||||
|
||||
@@ -568,5 +568,10 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
return noisy_samples
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
|
||||
def _get_conditioning_c_in(self, sigma):
|
||||
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
|
||||
return c_in
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -176,7 +176,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
|
||||
def precondition_inputs(self, sample, sigma):
|
||||
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
|
||||
c_in = self._get_conditioning_c_in(sigma)
|
||||
scaled_sample = sample * c_in
|
||||
return scaled_sample
|
||||
|
||||
@@ -703,5 +703,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
return noisy_samples
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
|
||||
def _get_conditioning_c_in(self, sigma):
|
||||
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
|
||||
return c_in
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -103,11 +103,13 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
# setable values
|
||||
self.num_inference_steps = None
|
||||
|
||||
sigmas = torch.arange(num_train_timesteps + 1) / num_train_timesteps
|
||||
sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
sigmas = torch.arange(num_train_timesteps + 1, dtype=sigmas_dtype) / num_train_timesteps
|
||||
if sigma_schedule == "karras":
|
||||
sigmas = self._compute_karras_sigmas(sigmas)
|
||||
elif sigma_schedule == "exponential":
|
||||
sigmas = self._compute_exponential_sigmas(sigmas)
|
||||
sigmas = sigmas.to(torch.float32)
|
||||
|
||||
self.timesteps = self.precondition_noise(sigmas)
|
||||
|
||||
@@ -159,7 +161,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._begin_index = begin_index
|
||||
|
||||
def precondition_inputs(self, sample, sigma):
|
||||
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
|
||||
c_in = self._get_conditioning_c_in(sigma)
|
||||
scaled_sample = sample * c_in
|
||||
return scaled_sample
|
||||
|
||||
@@ -230,18 +232,19 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
if sigmas is None:
|
||||
sigmas = torch.linspace(0, 1, self.num_inference_steps)
|
||||
sigmas = torch.linspace(0, 1, self.num_inference_steps, dtype=sigmas_dtype)
|
||||
elif isinstance(sigmas, float):
|
||||
sigmas = torch.tensor(sigmas, dtype=torch.float32)
|
||||
sigmas = torch.tensor(sigmas, dtype=sigmas_dtype)
|
||||
else:
|
||||
sigmas = sigmas
|
||||
sigmas = sigmas.to(sigmas_dtype)
|
||||
if self.config.sigma_schedule == "karras":
|
||||
sigmas = self._compute_karras_sigmas(sigmas)
|
||||
elif self.config.sigma_schedule == "exponential":
|
||||
sigmas = self._compute_exponential_sigmas(sigmas)
|
||||
|
||||
sigmas = sigmas.to(dtype=torch.float32, device=device)
|
||||
|
||||
self.timesteps = self.precondition_noise(sigmas)
|
||||
|
||||
if self.config.final_sigmas_type == "sigma_min":
|
||||
@@ -315,6 +318,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
s_noise: float = 1.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
pred_original_sample: Optional[torch.Tensor] = None,
|
||||
) -> Union[EDMEulerSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
||||
@@ -378,7 +382,8 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
pred_original_sample = self.precondition_outputs(sample, model_output, sigma_hat)
|
||||
if pred_original_sample is None:
|
||||
pred_original_sample = self.precondition_outputs(sample, model_output, sigma_hat)
|
||||
|
||||
# 2. Convert to an ODE derivative
|
||||
derivative = (sample - pred_original_sample) / sigma_hat
|
||||
@@ -435,5 +440,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
return noisy_samples
|
||||
|
||||
def _get_conditioning_c_in(self, sigma):
|
||||
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
|
||||
return c_in
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -62,9 +62,11 @@ from .import_utils import (
|
||||
get_objects_from_module,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_better_profanity_available,
|
||||
is_bitsandbytes_available,
|
||||
is_bitsandbytes_version,
|
||||
is_bs4_available,
|
||||
is_cosmos_guardrail_available,
|
||||
is_flax_available,
|
||||
is_ftfy_available,
|
||||
is_gguf_available,
|
||||
@@ -78,6 +80,7 @@ from .import_utils import (
|
||||
is_k_diffusion_version,
|
||||
is_librosa_available,
|
||||
is_matplotlib_available,
|
||||
is_nltk_available,
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_opencv_available,
|
||||
@@ -85,6 +88,7 @@ from .import_utils import (
|
||||
is_optimum_quanto_version,
|
||||
is_peft_available,
|
||||
is_peft_version,
|
||||
is_pytorch_retinaface_available,
|
||||
is_safetensors_available,
|
||||
is_scipy_available,
|
||||
is_sentencepiece_available,
|
||||
|
||||
@@ -160,6 +160,21 @@ class AutoencoderKLCogVideoX(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLCosmos(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanVideo(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -430,6 +445,21 @@ class ControlNetXSAdapter(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class CosmosTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class DiTTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -392,6 +392,51 @@ class CogView4Pipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ConsisIDPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class CosmosTextToWorldPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class CosmosVideoToWorldPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class CycleDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -215,6 +215,10 @@ _gguf_available, _gguf_version = _is_package_available("gguf")
|
||||
_torchao_available, _torchao_version = _is_package_available("torchao")
|
||||
_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
|
||||
_optimum_quanto_available, _optimum_quanto_version = _is_package_available("optimum", get_dist_name=True)
|
||||
_pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_available("pytorch_retinaface")
|
||||
_better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
|
||||
_nltk_available, _nltk_version = _is_package_available("nltk")
|
||||
_cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
|
||||
|
||||
|
||||
def is_torch_available():
|
||||
@@ -353,6 +357,22 @@ def is_timm_available():
|
||||
return _timm_available
|
||||
|
||||
|
||||
def is_pytorch_retinaface_available():
|
||||
return _pytorch_retinaface_available
|
||||
|
||||
|
||||
def is_better_profanity_available():
|
||||
return _better_profanity_available
|
||||
|
||||
|
||||
def is_nltk_available():
|
||||
return _nltk_available
|
||||
|
||||
|
||||
def is_cosmos_guardrail_available():
|
||||
return _cosmos_guardrail_available
|
||||
|
||||
|
||||
def is_hpu_available():
|
||||
return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))
|
||||
|
||||
@@ -505,6 +525,22 @@ QUANTO_IMPORT_ERROR = """
|
||||
install optimum-quanto`
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
PYTORCH_RETINAFACE_IMPORT_ERROR = """
|
||||
{0} requires the pytorch_retinaface library but it was not found in your environment. You can install it with pip: `pip install pytorch_retinaface`
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
BETTER_PROFANITY_IMPORT_ERROR = """
|
||||
{0} requires the better_profanity library but it was not found in your environment. You can install it with pip: `pip install better_profanity`
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
NLTK_IMPORT_ERROR = """
|
||||
{0} requires the nltk library but it was not found in your environment. You can install it with pip: `pip install nltk`
|
||||
"""
|
||||
|
||||
|
||||
BACKENDS_MAPPING = OrderedDict(
|
||||
[
|
||||
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
|
||||
@@ -533,6 +569,9 @@ BACKENDS_MAPPING = OrderedDict(
|
||||
("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)),
|
||||
("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)),
|
||||
("quanto", (is_optimum_quanto_available, QUANTO_IMPORT_ERROR)),
|
||||
("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)),
|
||||
("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)),
|
||||
("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
86
tests/models/autoencoders/test_models_autoencoder_cosmos.py
Normal file
86
tests/models/autoencoders/test_models_autoencoder_cosmos.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from diffusers import AutoencoderKLCosmos
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLCosmosTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLCosmos
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_kl_cosmos_config(self):
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 4,
|
||||
"encoder_block_out_channels": (8, 8, 8, 8),
|
||||
"decode_block_out_channels": (8, 8, 8, 8),
|
||||
"attention_resolutions": (8,),
|
||||
"resolution": 64,
|
||||
"num_layers": 2,
|
||||
"patch_size": 4,
|
||||
"patch_type": "haar",
|
||||
"scaling_factor": 1.0,
|
||||
"spatial_compression_ratio": 4,
|
||||
"temporal_compression_ratio": 4,
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
height = 32
|
||||
width = 32
|
||||
|
||||
image = floats_tensor((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 9, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_cosmos_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"CosmosEncoder3d",
|
||||
"CosmosDecoder3d",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@unittest.skip("Not sure why this test fails. Investigate later.")
|
||||
def test_effective_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
153
tests/models/transformers/test_models_transformer_cosmos.py
Normal file
153
tests/models/transformers/test_models_transformer_cosmos.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import CosmosTransformer3DModel
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class CosmosTransformer3DModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = CosmosTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 4
|
||||
num_frames = 1
|
||||
height = 16
|
||||
width = 16
|
||||
text_embed_dim = 16
|
||||
sequence_length = 12
|
||||
fps = 30
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
|
||||
attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
|
||||
padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"attention_mask": attention_mask,
|
||||
"fps": fps,
|
||||
"padding_mask": padding_mask,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 12,
|
||||
"num_layers": 2,
|
||||
"mlp_ratio": 2,
|
||||
"text_embed_dim": 16,
|
||||
"adaln_lora_dim": 4,
|
||||
"max_size": (4, 32, 32),
|
||||
"patch_size": (1, 2, 2),
|
||||
"rope_scale": (2.0, 1.0, 1.0),
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": "learnable",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"CosmosTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class CosmosTransformer3DModelVideoToWorldTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = CosmosTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 4
|
||||
num_frames = 1
|
||||
height = 16
|
||||
width = 16
|
||||
text_embed_dim = 16
|
||||
sequence_length = 12
|
||||
fps = 30
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
|
||||
attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
|
||||
condition_mask = torch.ones(batch_size, 1, num_frames, height, width).to(torch_device)
|
||||
padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"attention_mask": attention_mask,
|
||||
"fps": fps,
|
||||
"condition_mask": condition_mask,
|
||||
"padding_mask": padding_mask,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 4 + 1,
|
||||
"out_channels": 4,
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 12,
|
||||
"num_layers": 2,
|
||||
"mlp_ratio": 2,
|
||||
"text_embed_dim": 16,
|
||||
"adaln_lora_dim": 4,
|
||||
"max_size": (4, 32, 32),
|
||||
"patch_size": (1, 2, 2),
|
||||
"rope_scale": (2.0, 1.0, 1.0),
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": "learnable",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"CosmosTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
0
tests/pipelines/cosmos/__init__.py
Normal file
0
tests/pipelines/cosmos/__init__.py
Normal file
47
tests/pipelines/cosmos/cosmos_guardrail.py
Normal file
47
tests/pipelines/cosmos/cosmos_guardrail.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# Copyright 2024 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ===== This file is an implementation of a dummy guardrail for the fast tests =====
|
||||
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
class DummyCosmosSafetyChecker(ModelMixin, ConfigMixin):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._dtype = torch.float32
|
||||
|
||||
def check_text_safety(self, prompt: str) -> bool:
|
||||
return True
|
||||
|
||||
def check_video_safety(self, frames: np.ndarray) -> np.ndarray:
|
||||
return frames
|
||||
|
||||
def to(self, device: Union[str, torch.device] = None, dtype: torch.dtype = None) -> None:
|
||||
self._dtype = dtype
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return None
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return self._dtype
|
||||
350
tests/pipelines/cosmos/test_cosmos.py
Normal file
350
tests/pipelines/cosmos/test_cosmos.py
Normal file
@@ -0,0 +1,350 @@
|
||||
# Copyright 2024 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLCosmos, CosmosTextToWorldPipeline, CosmosTransformer3DModel, EDMEulerScheduler
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
from .cosmos_guardrail import DummyCosmosSafetyChecker
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class CosmosTextToWorldPipelineWrapper(CosmosTextToWorldPipeline):
|
||||
@staticmethod
|
||||
def from_pretrained(*args, **kwargs):
|
||||
kwargs["safety_checker"] = DummyCosmosSafetyChecker()
|
||||
return CosmosTextToWorldPipeline.from_pretrained(*args, **kwargs)
|
||||
|
||||
|
||||
class CosmosTextToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = CosmosTextToWorldPipelineWrapper
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
supports_dduf = False
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = CosmosTransformer3DModel(
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=16,
|
||||
num_layers=2,
|
||||
mlp_ratio=2,
|
||||
text_embed_dim=32,
|
||||
adaln_lora_dim=4,
|
||||
max_size=(4, 32, 32),
|
||||
patch_size=(1, 2, 2),
|
||||
rope_scale=(2.0, 1.0, 1.0),
|
||||
concat_padding_mask=True,
|
||||
extra_pos_embed_type="learnable",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLCosmos(
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
latent_channels=4,
|
||||
encoder_block_out_channels=(8, 8, 8, 8),
|
||||
decode_block_out_channels=(8, 8, 8, 8),
|
||||
attention_resolutions=(8,),
|
||||
resolution=64,
|
||||
num_layers=2,
|
||||
patch_size=4,
|
||||
patch_type="haar",
|
||||
scaling_factor=1.0,
|
||||
spatial_compression_ratio=4,
|
||||
temporal_compression_ratio=4,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = EDMEulerScheduler(
|
||||
sigma_min=0.002,
|
||||
sigma_max=80,
|
||||
sigma_data=0.5,
|
||||
sigma_schedule="karras",
|
||||
num_train_timesteps=1000,
|
||||
prediction_type="epsilon",
|
||||
rho=7.0,
|
||||
final_sigmas_type="sigma_min",
|
||||
)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
# We cannot run the Cosmos Guardrail for fast tests due to the large model size
|
||||
"safety_checker": DummyCosmosSafetyChecker(),
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "dance monkey",
|
||||
"negative_prompt": "bad quality",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 3.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"num_frames": 9,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
video = pipe(**inputs).frames
|
||||
generated_video = video[0]
|
||||
|
||||
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
|
||||
expected_video = torch.randn(9, 3, 32, 32)
|
||||
max_diff = np.abs(generated_video - expected_video).max()
|
||||
self.assertLessEqual(max_diff, 1e10)
|
||||
|
||||
def test_callback_inputs(self):
|
||||
sig = inspect.signature(self.pipeline_class.__call__)
|
||||
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
|
||||
has_callback_step_end = "callback_on_step_end" in sig.parameters
|
||||
|
||||
if not (has_callback_tensor_inputs and has_callback_step_end):
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
self.assertTrue(
|
||||
hasattr(pipe, "_callback_tensor_inputs"),
|
||||
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
|
||||
)
|
||||
|
||||
def callback_inputs_subset(pipe, i, t, callback_kwargs):
|
||||
# iterate over callback args
|
||||
for tensor_name, tensor_value in callback_kwargs.items():
|
||||
# check that we're only passing in allowed tensor inputs
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
def callback_inputs_all(pipe, i, t, callback_kwargs):
|
||||
for tensor_name in pipe._callback_tensor_inputs:
|
||||
assert tensor_name in callback_kwargs
|
||||
|
||||
# iterate over callback args
|
||||
for tensor_name, tensor_value in callback_kwargs.items():
|
||||
# check that we're only passing in allowed tensor inputs
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
# Test passing in a subset
|
||||
inputs["callback_on_step_end"] = callback_inputs_subset
|
||||
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
# Test passing in a everything
|
||||
inputs["callback_on_step_end"] = callback_inputs_all
|
||||
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
|
||||
is_last = i == (pipe.num_timesteps - 1)
|
||||
if is_last:
|
||||
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
|
||||
return callback_kwargs
|
||||
|
||||
inputs["callback_on_step_end"] = callback_inputs_change_tensor
|
||||
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||
output = pipe(**inputs)[0]
|
||||
assert output.abs().sum() < 1e10
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-2)
|
||||
|
||||
def test_attention_slicing_forward_pass(
|
||||
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
|
||||
):
|
||||
if not self.test_attention_slicing:
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_without_slicing = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=1)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing1 = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=2)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing2 = pipe(**inputs)[0]
|
||||
|
||||
if test_max_difference:
|
||||
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
|
||||
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
|
||||
self.assertLess(
|
||||
max(max_diff1, max_diff2),
|
||||
expected_max_diff,
|
||||
"Attention slicing should not affect the inference results",
|
||||
)
|
||||
|
||||
def test_vae_tiling(self, expected_diff_max: float = 0.2):
|
||||
generator_device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to("cpu")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Without tiling
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
output_without_tiling = pipe(**inputs)[0]
|
||||
|
||||
# With tiling
|
||||
pipe.vae.enable_tiling(
|
||||
tile_sample_min_height=96,
|
||||
tile_sample_min_width=96,
|
||||
tile_sample_stride_height=64,
|
||||
tile_sample_stride_width=64,
|
||||
)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
output_with_tiling = pipe(**inputs)[0]
|
||||
|
||||
self.assertLess(
|
||||
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
|
||||
expected_diff_max,
|
||||
"VAE tiling should not affect the inference results",
|
||||
)
|
||||
|
||||
def test_save_load_optional_components(self, expected_max_difference=1e-4):
|
||||
self.pipeline_class._optional_components.remove("safety_checker")
|
||||
super().test_save_load_optional_components(expected_max_difference=expected_max_difference)
|
||||
self.pipeline_class._optional_components.append("safety_checker")
|
||||
|
||||
def test_serialization_with_variants(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
model_components = [
|
||||
component_name
|
||||
for component_name, component in pipe.components.items()
|
||||
if isinstance(component, torch.nn.Module)
|
||||
]
|
||||
model_components.remove("safety_checker")
|
||||
variant = "fp16"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
|
||||
|
||||
with open(f"{tmpdir}/model_index.json", "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
for subfolder in os.listdir(tmpdir):
|
||||
if not os.path.isfile(subfolder) and subfolder in model_components:
|
||||
folder_path = os.path.join(tmpdir, subfolder)
|
||||
is_folder = os.path.isdir(folder_path) and subfolder in config
|
||||
assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
|
||||
|
||||
def test_torch_dtype_dict(self):
|
||||
components = self.get_dummy_components()
|
||||
if not components:
|
||||
self.skipTest("No dummy components defined.")
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
|
||||
specified_key = next(iter(components.keys()))
|
||||
|
||||
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
|
||||
pipe.save_pretrained(tmpdirname, safe_serialization=False)
|
||||
torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
|
||||
loaded_pipe = self.pipeline_class.from_pretrained(
|
||||
tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict
|
||||
)
|
||||
|
||||
for name, component in loaded_pipe.components.items():
|
||||
if name == "safety_checker":
|
||||
continue
|
||||
if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
|
||||
expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
|
||||
self.assertEqual(
|
||||
component.dtype,
|
||||
expected_dtype,
|
||||
f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
|
||||
)
|
||||
|
||||
@unittest.skip(
|
||||
"The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in "
|
||||
"a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is "
|
||||
"too large and slow to run on CI."
|
||||
)
|
||||
def test_encode_prompt_works_in_isolation(self):
|
||||
pass
|
||||
363
tests/pipelines/cosmos/test_cosmos_video2world.py
Normal file
363
tests/pipelines/cosmos/test_cosmos_video2world.py
Normal file
@@ -0,0 +1,363 @@
|
||||
# Copyright 2024 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLCosmos, CosmosTransformer3DModel, CosmosVideoToWorldPipeline, EDMEulerScheduler
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
from .cosmos_guardrail import DummyCosmosSafetyChecker
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class CosmosVideoToWorldPipelineWrapper(CosmosVideoToWorldPipeline):
|
||||
@staticmethod
|
||||
def from_pretrained(*args, **kwargs):
|
||||
kwargs["safety_checker"] = DummyCosmosSafetyChecker()
|
||||
return CosmosVideoToWorldPipeline.from_pretrained(*args, **kwargs)
|
||||
|
||||
|
||||
class CosmosVideoToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = CosmosVideoToWorldPipelineWrapper
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image", "video"})
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
supports_dduf = False
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = CosmosTransformer3DModel(
|
||||
in_channels=4 + 1,
|
||||
out_channels=4,
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=16,
|
||||
num_layers=2,
|
||||
mlp_ratio=2,
|
||||
text_embed_dim=32,
|
||||
adaln_lora_dim=4,
|
||||
max_size=(4, 32, 32),
|
||||
patch_size=(1, 2, 2),
|
||||
rope_scale=(2.0, 1.0, 1.0),
|
||||
concat_padding_mask=True,
|
||||
extra_pos_embed_type="learnable",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLCosmos(
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
latent_channels=4,
|
||||
encoder_block_out_channels=(8, 8, 8, 8),
|
||||
decode_block_out_channels=(8, 8, 8, 8),
|
||||
attention_resolutions=(8,),
|
||||
resolution=64,
|
||||
num_layers=2,
|
||||
patch_size=4,
|
||||
patch_type="haar",
|
||||
scaling_factor=1.0,
|
||||
spatial_compression_ratio=4,
|
||||
temporal_compression_ratio=4,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = EDMEulerScheduler(
|
||||
sigma_min=0.002,
|
||||
sigma_max=80,
|
||||
sigma_data=0.5,
|
||||
sigma_schedule="karras",
|
||||
num_train_timesteps=1000,
|
||||
prediction_type="epsilon",
|
||||
rho=7.0,
|
||||
final_sigmas_type="sigma_min",
|
||||
)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
# We cannot run the Cosmos Guardrail for fast tests due to the large model size
|
||||
"safety_checker": DummyCosmosSafetyChecker(),
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
image_height = 32
|
||||
image_width = 32
|
||||
image = PIL.Image.new("RGB", (image_width, image_height))
|
||||
|
||||
inputs = {
|
||||
"image": image,
|
||||
"prompt": "dance monkey",
|
||||
"negative_prompt": "bad quality",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 3.0,
|
||||
"height": image_height,
|
||||
"width": image_width,
|
||||
"num_frames": 9,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
video = pipe(**inputs).frames
|
||||
generated_video = video[0]
|
||||
|
||||
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
|
||||
expected_video = torch.randn(9, 3, 32, 32)
|
||||
max_diff = np.abs(generated_video - expected_video).max()
|
||||
self.assertLessEqual(max_diff, 1e10)
|
||||
|
||||
def test_components_function(self):
|
||||
init_components = self.get_dummy_components()
|
||||
init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))}
|
||||
pipe = self.pipeline_class(**init_components)
|
||||
self.assertTrue(hasattr(pipe, "components"))
|
||||
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
|
||||
|
||||
def test_callback_inputs(self):
|
||||
sig = inspect.signature(self.pipeline_class.__call__)
|
||||
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
|
||||
has_callback_step_end = "callback_on_step_end" in sig.parameters
|
||||
|
||||
if not (has_callback_tensor_inputs and has_callback_step_end):
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
self.assertTrue(
|
||||
hasattr(pipe, "_callback_tensor_inputs"),
|
||||
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
|
||||
)
|
||||
|
||||
def callback_inputs_subset(pipe, i, t, callback_kwargs):
|
||||
# iterate over callback args
|
||||
for tensor_name, tensor_value in callback_kwargs.items():
|
||||
# check that we're only passing in allowed tensor inputs
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
def callback_inputs_all(pipe, i, t, callback_kwargs):
|
||||
for tensor_name in pipe._callback_tensor_inputs:
|
||||
assert tensor_name in callback_kwargs
|
||||
|
||||
# iterate over callback args
|
||||
for tensor_name, tensor_value in callback_kwargs.items():
|
||||
# check that we're only passing in allowed tensor inputs
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
# Test passing in a subset
|
||||
inputs["callback_on_step_end"] = callback_inputs_subset
|
||||
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
# Test passing in a everything
|
||||
inputs["callback_on_step_end"] = callback_inputs_all
|
||||
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
|
||||
is_last = i == (pipe.num_timesteps - 1)
|
||||
if is_last:
|
||||
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
|
||||
return callback_kwargs
|
||||
|
||||
inputs["callback_on_step_end"] = callback_inputs_change_tensor
|
||||
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||
output = pipe(**inputs)[0]
|
||||
assert output.abs().sum() < 1e10
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-2)
|
||||
|
||||
def test_attention_slicing_forward_pass(
|
||||
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
|
||||
):
|
||||
if not self.test_attention_slicing:
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_without_slicing = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=1)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing1 = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=2)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing2 = pipe(**inputs)[0]
|
||||
|
||||
if test_max_difference:
|
||||
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
|
||||
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
|
||||
self.assertLess(
|
||||
max(max_diff1, max_diff2),
|
||||
expected_max_diff,
|
||||
"Attention slicing should not affect the inference results",
|
||||
)
|
||||
|
||||
def test_vae_tiling(self, expected_diff_max: float = 0.2):
|
||||
generator_device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to("cpu")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Without tiling
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
output_without_tiling = pipe(**inputs)[0]
|
||||
|
||||
# With tiling
|
||||
pipe.vae.enable_tiling(
|
||||
tile_sample_min_height=96,
|
||||
tile_sample_min_width=96,
|
||||
tile_sample_stride_height=64,
|
||||
tile_sample_stride_width=64,
|
||||
)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
output_with_tiling = pipe(**inputs)[0]
|
||||
|
||||
self.assertLess(
|
||||
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
|
||||
expected_diff_max,
|
||||
"VAE tiling should not affect the inference results",
|
||||
)
|
||||
|
||||
def test_save_load_optional_components(self, expected_max_difference=1e-4):
|
||||
self.pipeline_class._optional_components.remove("safety_checker")
|
||||
super().test_save_load_optional_components(expected_max_difference=expected_max_difference)
|
||||
self.pipeline_class._optional_components.append("safety_checker")
|
||||
|
||||
def test_serialization_with_variants(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
model_components = [
|
||||
component_name
|
||||
for component_name, component in pipe.components.items()
|
||||
if isinstance(component, torch.nn.Module)
|
||||
]
|
||||
model_components.remove("safety_checker")
|
||||
variant = "fp16"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
|
||||
|
||||
with open(f"{tmpdir}/model_index.json", "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
for subfolder in os.listdir(tmpdir):
|
||||
if not os.path.isfile(subfolder) and subfolder in model_components:
|
||||
folder_path = os.path.join(tmpdir, subfolder)
|
||||
is_folder = os.path.isdir(folder_path) and subfolder in config
|
||||
assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
|
||||
|
||||
def test_torch_dtype_dict(self):
|
||||
components = self.get_dummy_components()
|
||||
if not components:
|
||||
self.skipTest("No dummy components defined.")
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
|
||||
specified_key = next(iter(components.keys()))
|
||||
|
||||
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
|
||||
pipe.save_pretrained(tmpdirname, safe_serialization=False)
|
||||
torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
|
||||
loaded_pipe = self.pipeline_class.from_pretrained(
|
||||
tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict
|
||||
)
|
||||
|
||||
for name, component in loaded_pipe.components.items():
|
||||
if name == "safety_checker":
|
||||
continue
|
||||
if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
|
||||
expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
|
||||
self.assertEqual(
|
||||
component.dtype,
|
||||
expected_dtype,
|
||||
f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
|
||||
)
|
||||
|
||||
@unittest.skip(
|
||||
"The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in "
|
||||
"a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is "
|
||||
"too large and slow to run on CI."
|
||||
)
|
||||
def test_encode_prompt_works_in_isolation(self):
|
||||
pass
|
||||
@@ -2291,7 +2291,6 @@ class PipelineTesterMixin:
|
||||
self.skipTest("No dummy components defined.")
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
|
||||
specified_key = next(iter(components.keys()))
|
||||
|
||||
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
|
||||
|
||||
Reference in New Issue
Block a user