#!/usr/bin/env python3
import argparse
import os

import jax
import numpy as onp
import torch
import torch.nn as nn
from music_spectrogram_diffusion import inference
from t5x import checkpoints

from diffusers import DDPMScheduler, OnnxRuntimeModel, SpectrogramDiffusionPipeline
from diffusers.pipelines.spectrogram_diffusion import SpectrogramContEncoder, SpectrogramNotesEncoder, T5FilmDecoder


MODEL = "base_with_context"


def load_notes_encoder(weights, model):
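    """Port the T5X note-token encoder weights into the SpectrogramNotesEncoder.

    Flax stores dense kernels as (in_features, out_features), so every kernel is
    transposed (`.T`) before being wrapped in an nn.Parameter for the torch Linear layers.
    """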
    model.token_embedder.weight = nn.Parameter(torch.FloatTensor(weights["token_embedder"]["embedding"]))
    model.position_encoding.weight = nn.Parameter(
        torch.FloatTensor(weights["Embed_0"]["embedding"]), requires_grad=False
    )
    for lyr_num, lyr in enumerate(model.encoders):
        ly_weight = weights[f"layers_{lyr_num}"]
        lyr.layer[0].layer_norm.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["pre_attention_layer_norm"]["scale"])
        )

        attention_weights = ly_weight["attention"]
        lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
        lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
        lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
        lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))

        lyr.layer[1].layer_norm.weight = nn.Parameter(torch.FloatTensor(ly_weight["pre_mlp_layer_norm"]["scale"]))

        lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
        lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
        lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wo"]["kernel"].T))

    model.layer_norm.weight = nn.Parameter(torch.FloatTensor(weights["encoder_norm"]["scale"]))
    return model


def load_continuous_encoder(weights, model):
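    """Port the T5X continuous (spectrogram context) encoder weights into the SpectrogramContEncoder."""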
    model.input_proj.weight = nn.Parameter(torch.FloatTensor(weights["input_proj"]["kernel"].T))

    model.position_encoding.weight = nn.Parameter(
        torch.FloatTensor(weights["Embed_0"]["embedding"]), requires_grad=False
    )

    for lyr_num, lyr in enumerate(model.encoders):
        ly_weight = weights[f"layers_{lyr_num}"]
        attention_weights = ly_weight["attention"]

        lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
        lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
        lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
        lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))
        lyr.layer[0].layer_norm.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["pre_attention_layer_norm"]["scale"])
        )

        lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
        lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
        lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wo"]["kernel"].T))
        lyr.layer[1].layer_norm.weight = nn.Parameter(torch.FloatTensor(ly_weight["pre_mlp_layer_norm"]["scale"]))

    model.layer_norm.weight = nn.Parameter(torch.FloatTensor(weights["encoder_norm"]["scale"]))

    return model


def load_decoder(weights, model):
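    """Port the T5X decoder weights into the T5FilmDecoder.

    Each decoder block carries self-attention, cross-attention (MultiHeadDotProductAttention_0
    in the checkpoint), FiLM scale/shift projections, and a gated MLP; the final layer norm and
    spec_out projection are loaded at the end.
    """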
    model.conditioning_emb[0].weight = nn.Parameter(torch.FloatTensor(weights["time_emb_dense0"]["kernel"].T))
    model.conditioning_emb[2].weight = nn.Parameter(torch.FloatTensor(weights["time_emb_dense1"]["kernel"].T))

    model.position_encoding.weight = nn.Parameter(
        torch.FloatTensor(weights["Embed_0"]["embedding"]), requires_grad=False
    )

    model.continuous_inputs_projection.weight = nn.Parameter(
        torch.FloatTensor(weights["continuous_inputs_projection"]["kernel"].T)
    )

    for lyr_num, lyr in enumerate(model.decoders):
        ly_weight = weights[f"layers_{lyr_num}"]
        lyr.layer[0].layer_norm.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["pre_self_attention_layer_norm"]["scale"])
        )

        lyr.layer[0].FiLMLayer.scale_bias.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["FiLMLayer_0"]["DenseGeneral_0"]["kernel"].T)
        )

        attention_weights = ly_weight["self_attention"]
        lyr.layer[0].attention.to_q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
        lyr.layer[0].attention.to_k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
        lyr.layer[0].attention.to_v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
        lyr.layer[0].attention.to_out[0].weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))

        attention_weights = ly_weight["MultiHeadDotProductAttention_0"]
        lyr.layer[1].attention.to_q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
        lyr.layer[1].attention.to_k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
        lyr.layer[1].attention.to_v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
        lyr.layer[1].attention.to_out[0].weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))
        lyr.layer[1].layer_norm.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["pre_cross_attention_layer_norm"]["scale"])
        )

        lyr.layer[2].layer_norm.weight = nn.Parameter(torch.FloatTensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
        lyr.layer[2].film.scale_bias.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["FiLMLayer_1"]["DenseGeneral_0"]["kernel"].T)
        )
        lyr.layer[2].DenseReluDense.wi_0.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
        lyr.layer[2].DenseReluDense.wi_1.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
        lyr.layer[2].DenseReluDense.wo.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wo"]["kernel"].T))

    model.decoder_norm.weight = nn.Parameter(torch.FloatTensor(weights["decoder_norm"]["scale"]))

    model.spec_out.weight = nn.Parameter(torch.FloatTensor(weights["spec_out_dense"]["kernel"].T))

    return model


def main(args):
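    """Convert a music-spectrogram-diffusion T5X checkpoint into a SpectrogramDiffusionPipeline."""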
    t5_checkpoint = checkpoints.load_t5x_checkpoint(args.checkpoint_path)
    t5_checkpoint = jax.tree_util.tree_map(onp.array, t5_checkpoint)
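
    # Gin overrides mirror the eval-time setup: classifier-free guidance with a
    # condition weight of 2.0, registered dynamically on top of the training config.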
    gin_overrides = [
        "from __gin__ import dynamic_registration",
        "from music_spectrogram_diffusion.models.diffusion import diffusion_utils",
        "diffusion_utils.ClassifierFreeGuidanceConfig.eval_condition_weight = 2.0",
        "diffusion_utils.DiffusionConfig.classifier_free_guidance = @diffusion_utils.ClassifierFreeGuidanceConfig()",
    ]

    gin_file = os.path.join(args.checkpoint_path, "..", "config.gin")
    gin_config = inference.parse_training_gin_file(gin_file, gin_overrides)
    synth_model = inference.InferenceModel(args.checkpoint_path, gin_config)
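
    # DDPM scheduler for the converted pipeline (squared-cosine betas, fixed-large variance).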
    scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", variance_type="fixed_large")
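
    # Instantiate the three diffusers modules, taking their shapes from the T5X model config:
    # a note-token encoder, a continuous (spectrogram context) encoder, and the FiLM-conditioned decoder.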
    notes_encoder = SpectrogramNotesEncoder(
        max_length=synth_model.sequence_length["inputs"],
        vocab_size=synth_model.model.module.config.vocab_size,
        d_model=synth_model.model.module.config.emb_dim,
        dropout_rate=synth_model.model.module.config.dropout_rate,
        num_layers=synth_model.model.module.config.num_encoder_layers,
        num_heads=synth_model.model.module.config.num_heads,
        d_kv=synth_model.model.module.config.head_dim,
        d_ff=synth_model.model.module.config.mlp_dim,
        feed_forward_proj="gated-gelu",
    )

    continuous_encoder = SpectrogramContEncoder(
        input_dims=synth_model.audio_codec.n_dims,
        targets_context_length=synth_model.sequence_length["targets_context"],
        d_model=synth_model.model.module.config.emb_dim,
        dropout_rate=synth_model.model.module.config.dropout_rate,
        num_layers=synth_model.model.module.config.num_encoder_layers,
        num_heads=synth_model.model.module.config.num_heads,
        d_kv=synth_model.model.module.config.head_dim,
        d_ff=synth_model.model.module.config.mlp_dim,
        feed_forward_proj="gated-gelu",
    )

    decoder = T5FilmDecoder(
        input_dims=synth_model.audio_codec.n_dims,
        targets_length=synth_model.sequence_length["targets_context"],
        max_decoder_noise_time=synth_model.model.module.config.max_decoder_noise_time,
        d_model=synth_model.model.module.config.emb_dim,
        num_layers=synth_model.model.module.config.num_decoder_layers,
        num_heads=synth_model.model.module.config.num_heads,
        d_kv=synth_model.model.module.config.head_dim,
        d_ff=synth_model.model.module.config.mlp_dim,
        dropout_rate=synth_model.model.module.config.dropout_rate,
    )
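
    # Copy the T5X checkpoint weights into the freshly initialized PyTorch modules.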
    notes_encoder = load_notes_encoder(t5_checkpoint["target"]["token_encoder"], notes_encoder)
    continuous_encoder = load_continuous_encoder(t5_checkpoint["target"]["continuous_encoder"], continuous_encoder)
    decoder = load_decoder(t5_checkpoint["target"]["decoder"], decoder)
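
    # The MelGAN spectrogram-to-audio vocoder is pulled as a ready-made ONNX model rather than converted here.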
    melgan = OnnxRuntimeModel.from_pretrained("kashif/soundstream_mel_decoder")

    pipe = SpectrogramDiffusionPipeline(
        notes_encoder=notes_encoder,
        continuous_encoder=continuous_encoder,
        decoder=decoder,
        scheduler=scheduler,
        melgan=melgan,
    )
    if args.save:
        pipe.save_pretrained(args.output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--output_path", default=None, type=str, required=True, help="Path to the converted model.")
    # Note: argparse's type=bool treats any non-empty string as True, so the converted
    # pipeline is saved unless --save is passed an empty string.
    parser.add_argument(
        "--save", default=True, type=bool, required=False, help="Whether to save the converted model or not."
    )
    parser.add_argument(
        "--checkpoint_path",
        default=f"{MODEL}/checkpoint_500000",
        type=str,
        required=False,
        help="Path to the original jax model checkpoint.",
    )
    args = parser.parse_args()

    main(args)
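
# Example invocation (hypothetical local paths; adjust the script name to your checkout):
#   python convert_music_spectrogram_to_diffusers.py \
#       --checkpoint_path base_with_context/checkpoint_500000 \
#       --output_path music_spectrogram_diffusion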