from transformers import Wav2Vec2Config, Wav2Vec2Model
from transformers.modeling_outputs import BaseModelOutput
import torch
import torch.nn.functional as F


def get_mask_from_lengths(lengths, max_len=None):
    # Build a boolean padding mask of shape (batch, max_len) that is True for
    # valid positions and False for padding.
    lengths = lengths.to(torch.long)
    if max_len is None:
        max_len = torch.max(lengths).item()

    ids = torch.arange(0, max_len).unsqueeze(0).expand(lengths.shape[0], -1).to(lengths.device)
    mask = ids < lengths.unsqueeze(1).expand(-1, max_len)

    return mask

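# Usage sketch for get_mask_from_lengths (illustrative only; nothing else in
# this file calls it):
#   get_mask_from_lengths(torch.tensor([2, 4]))
#   -> tensor([[ True,  True, False, False],
#              [ True,  True,  True,  True]])

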
def linear_interpolation(features, seq_len):
    # Resample (batch, time, channels) features to seq_len frames along the
    # time axis using linear interpolation.
    features = features.transpose(1, 2)
    output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
    return output_features.transpose(1, 2)

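# For example (shapes only, illustrative): a (1, 49, 512) feature tensor with
# seq_len=25 comes back as (1, 25, 512).

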
# the implementation of Wav2Vec2Model is borrowed from
# https://github.com/huggingface/transformers/blob/HEAD/src/transformers/models/wav2vec2/modeling_wav2vec2.py
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
class Wav2Vec2Model(Wav2Vec2Model):
    def __init__(self, config: Wav2Vec2Config):
        super().__init__(config)
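
    # Note (added comment): unlike the stock transformers forward, this
    # override takes a seq_len argument and linearly resamples the CNN
    # features to exactly seq_len frames before the transformer encoder,
    # presumably so the audio features line up with a target frame count.
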
    def forward(
        self,
        input_values,
        seq_len,
        attention_mask=None,
        mask_time_indices=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        #self.config.output_attentions = True

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)
        extract_features = linear_interpolation(extract_features, seq_len=seq_len)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, ) + encoder_outputs[1:]
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def feature_extract(
        self,
        input_values,
        seq_len,
    ):
        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)
        extract_features = linear_interpolation(extract_features, seq_len=seq_len)

        return extract_features

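    # feature_extract and encode together split forward into two stages: the
    # convolutional features can be computed (and cached) once, then pushed
    # through the transformer encoder separately via encode.
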
    def encode(
        self,
        extract_features,
        attention_mask=None,
        mask_time_indices=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        #self.config.output_attentions = True

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, ) + encoder_outputs[1:]
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
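

# Usage sketch (illustrative; the checkpoint name and shapes below are
# assumptions, not something this file prescribes):
#
#   model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
#   waveform = torch.randn(1, 16000)               # ~1 s of 16 kHz audio
#   out = model(waveform, seq_len=25)              # resample to 25 frames
#   audio_emb = out.last_hidden_state              # (1, 25, hidden_size)
#
#   # Equivalent two-stage path, useful when the CNN features are cached:
#   feats = model.feature_extract(waveform, seq_len=25)
#   audio_emb = model.encode(feats).last_hidden_state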