From 77b4c13e80c6182c1ada4a8e7164b788a20e3891 Mon Sep 17 00:00:00 2001 From: Moughees Ahmed Date: Fri, 8 Mar 2024 15:09:44 -0500 Subject: [PATCH] add aligner Net --- aligner.py | 235 ++++++++++++++++++++++++++++++++++++++++++++++++++++ models.py | 81 +++++++++++------- train_ms.py | 3 +- 3 files changed, 288 insertions(+), 31 deletions(-) create mode 100644 aligner.py diff --git a/aligner.py b/aligner.py new file mode 100644 index 0000000..d89fe1b --- /dev/null +++ b/aligner.py @@ -0,0 +1,235 @@ +from typing import Tuple +import numpy as np + +import torch +from torch import nn, Tensor +from torch.nn import Module +import torch.nn.functional as F + +from einops import rearrange, repeat + +from beartype import beartype +from beartype.typing import Optional + +def exists(val): + return val is not None + +class AlignerNet(Module): + """alignment model https://arxiv.org/pdf/2108.10447.pdf """ + def __init__( + self, + dim_in=80, + dim_hidden=512, + attn_channels=80, + temperature=0.0005, + ): + super().__init__() + self.temperature = temperature + + self.key_layers = nn.ModuleList([ + nn.Conv1d( + dim_hidden, + dim_hidden * 2, + kernel_size=3, + padding=1, + bias=True, + ), + nn.ReLU(inplace=True), + nn.Conv1d(dim_hidden * 2, attn_channels, kernel_size=1, padding=0, bias=True) + ]) + + self.query_layers = nn.ModuleList([ + nn.Conv1d( + dim_in, + dim_in * 2, + kernel_size=3, + padding=1, + bias=True, + ), + nn.ReLU(inplace=True), + nn.Conv1d(dim_in * 2, dim_in, kernel_size=1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv1d(dim_in, attn_channels, kernel_size=1, padding=0, bias=True) + ]) + + @beartype + def forward( + self, + queries: Tensor, + keys: Tensor, + mask: Optional[Tensor] = None + ): + key_out = keys + for layer in self.key_layers: + key_out = layer(key_out) + + query_out = queries + for layer in self.query_layers: + query_out = layer(query_out) + + key_out = rearrange(key_out, 'b c t -> b t c') + query_out = rearrange(query_out, 'b c t -> b t c') + + attn_logp = torch.cdist(query_out, key_out) + attn_logp = rearrange(attn_logp, 'b ... -> b 1 ...') + + if exists(mask): + mask = rearrange(mask.bool(), '... c -> ... 1 c') + attn_logp.data.masked_fill_(~mask, -torch.finfo(attn_logp.dtype).max) + + attn = attn_logp.softmax(dim = -1) + return attn, attn_logp + +def pad_tensor(input, pad, value=0): + pad = [item for sublist in reversed(pad) for item in sublist] # Flatten the tuple + assert len(pad) // 2 == len(input.shape), 'Padding dimensions do not match input dimensions' + return F.pad(input, pad, mode='constant', value=value) + +def maximum_path(value, mask, const=None): + device = value.device + dtype = value.dtype + if not exists(const): + const = torch.tensor(float('-inf')).to(device) # Patch for Sphinx complaint + value = value * mask + + b, t_x, t_y = value.shape + direction = torch.zeros(value.shape, dtype=torch.int64, device=device) + v = torch.zeros((b, t_x), dtype=torch.float32, device=device) + x_range = torch.arange(t_x, dtype=torch.float32, device=device).view(1, -1) + + for j in range(t_y): + v0 = pad_tensor(v, ((0, 0), (1, 0)), value = const)[:, :-1] + v1 = v + max_mask = v1 >= v0 + v_max = torch.where(max_mask, v1, v0) + direction[:, :, j] = max_mask + + index_mask = x_range <= j + v = torch.where(index_mask.view(1,-1), v_max + value[:, :, j], const) + + direction = torch.where(mask.bool(), direction, 1) + + path = torch.zeros(value.shape, dtype=torch.float32, device=device) + index = mask[:, :, 0].sum(1).long() - 1 + index_range = torch.arange(b, device=device) + + for j in reversed(range(t_y)): + path[index_range, index, j] = 1 + index = index + direction[index_range, index, j] - 1 + + path = path * mask.float() + path = path.to(dtype=dtype) + return path + +class ForwardSumLoss(Module): + def __init__( + self, + blank_logprob = -1 + ): + super().__init__() + self.blank_logprob = blank_logprob + + self.ctc_loss = torch.nn.CTCLoss( + blank = 0, # check this value + zero_infinity = True + ) + + def forward(self, attn_logprob, key_lens, query_lens): + device, blank_logprob = attn_logprob.device, self.blank_logprob + max_key_len = attn_logprob.size(-1) + + # Reorder input to [query_len, batch_size, key_len] + attn_logprob = rearrange(attn_logprob, 'b 1 c t -> c b t') + + # Add blank label + attn_logprob = F.pad(attn_logprob, (1, 0, 0, 0, 0, 0), value = blank_logprob) + + # Convert to log probabilities + # Note: Mask out probs beyond key_len + mask_value = -torch.finfo(attn_logprob.dtype).max + attn_logprob.masked_fill_(torch.arange(max_key_len + 1, device=device, dtype=torch.long).view(1, 1, -1) > key_lens.view(1, -1, 1), mask_value) + + attn_logprob = attn_logprob.log_softmax(dim = -1) + + # Target sequences + target_seqs = torch.arange(1, max_key_len + 1, device=device, dtype=torch.long) + target_seqs = repeat(target_seqs, 'n -> b n', b = key_lens.numel()) + + # Evaluate CTC loss + cost = self.ctc_loss(attn_logprob, target_seqs, query_lens, key_lens) + + return cost + +class BinLoss(Module): + def forward(self, attn_hard, attn_logprob, key_lens): + batch, device = attn_logprob.shape[0], attn_logprob.device + max_key_len = attn_logprob.size(-1) + + # Reorder input to [query_len, batch_size, key_len] + attn_logprob = rearrange(attn_logprob, 'b 1 c t -> c b t') + attn_hard = rearrange(attn_hard, 'b t c -> c b t') + + mask_value = -torch.finfo(attn_logprob.dtype).max + + attn_logprob.masked_fill_(torch.arange(max_key_len, device=device, dtype=torch.long).view(1, 1, -1) > key_lens.view(1, -1, 1), mask_value) + attn_logprob = attn_logprob.log_softmax(dim = -1) + + return (attn_hard * attn_logprob).sum() / batch + +class Aligner(Module): + def __init__( + self, + dim_in, + dim_hidden, + attn_channels=80, + temperature=0.0005 + ): + super().__init__() + self.dim_in = dim_in + self.dim_hidden = dim_hidden + self.attn_channels = attn_channels + self.temperature = temperature + self.aligner = AlignerNet( + dim_in = self.dim_in, + dim_hidden = self.dim_hidden, + attn_channels = self.attn_channels, + temperature = self.temperature + ) + + def forward( + self, + x, + x_mask, + y, + y_mask + ): + alignment_soft, alignment_logprob = self.aligner(y, rearrange(x, 'b d t -> b t d'), x_mask) + + x_mask = rearrange(x_mask, '... i -> ... i 1') + y_mask = rearrange(y_mask, '... j -> ... 1 j') + attn_mask = x_mask * y_mask + attn_mask = rearrange(attn_mask, 'b 1 i j -> b i j') + + alignment_soft = rearrange(alignment_soft, 'b 1 c t -> b t c') + alignment_mask = maximum_path(alignment_soft, attn_mask) + + alignment_hard = torch.sum(alignment_mask, -1).int() + return alignment_hard, alignment_soft, alignment_logprob, alignment_mask + +if __name__ == '__main__': + batch_size = 10 + seq_len_y = 200 # length of sequence y + seq_len_x = 35 + feature_dim = 80 # feature dimension + + x = torch.randn(batch_size, 512, seq_len_x) + x = x.transpose(1,2) #dim-1 is the channels for conv + y = torch.randn(batch_size, seq_len_y, feature_dim) + y = y.transpose(1,2) #dim-1 is the channels for conv + + # Create masks + x_mask = torch.ones(batch_size, 1, seq_len_x) + y_mask = torch.ones(batch_size, 1, seq_len_y) + + align = Aligner(dim_in = 80, dim_hidden=512, attn_channels=80) + alignment_hard, alignment_soft, alignment_logprob, alignment_mas = align(x, x_mask, y, y_mask) \ No newline at end of file diff --git a/models.py b/models.py index dc8fa64..7519365 100644 --- a/models.py +++ b/models.py @@ -12,6 +12,7 @@ import modules import monotonic_align from commons import get_padding, init_weights +from .aligner import Aligner, ForwardSumLoss, BinLoss AVAILABLE_FLOW_TYPES = [ "pre_conv", @@ -1231,6 +1232,16 @@ def __init__( self.dp = DurationPredictor( hidden_channels, 256, 3, 0.5, gin_channels=gin_channels ) + + self.aligner = Aligner( + dim_in=80, + dim_hidden=self.enc_gin_channels, + attn_channels=self.enc_gin_channels, + ) + + self.aligner_loss = ForwardSumLoss() + self.bin_loss = BinLoss() + self.aligner_bin_loss_weight = 0.0 if n_speakers > 1: self.emb_g = nn.Embedding(n_speakers, gin_channels) @@ -1245,37 +1256,46 @@ def forward(self, x, x_lengths, y, y_lengths, sid=None): z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) z_p = self.flow(z, y_mask, g=g) - with torch.no_grad(): - # negative cross-entropy - s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t] - neg_cent1 = torch.sum( - -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True - ) # [b, 1, t_s] - neg_cent2 = torch.matmul( - -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r - ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] - neg_cent3 = torch.matmul( - z_p.transpose(1, 2), (m_p * s_p_sq_r) - ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] - neg_cent4 = torch.sum( - -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True - ) # [b, 1, t_s] - neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 - - if self.use_noise_scaled_mas: - epsilon = ( - torch.std(neg_cent) - * torch.randn_like(neg_cent) - * self.current_mas_noise_scale - ) - neg_cent = neg_cent + epsilon - - attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) - attn = ( - monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)) - .unsqueeze(1) - .detach() + # with torch.no_grad(): + # # negative cross-entropy + # s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t] + # neg_cent1 = torch.sum( + # -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True + # ) # [b, 1, t_s] + # neg_cent2 = torch.matmul( + # -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r + # ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] + # neg_cent3 = torch.matmul( + # z_p.transpose(1, 2), (m_p * s_p_sq_r) + # ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] + # neg_cent4 = torch.sum( + # -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True + # ) # [b, 1, t_s] + # neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 + + # if self.use_noise_scaled_mas: + # epsilon = ( + # torch.std(neg_cent) + # * torch.randn_like(neg_cent) + # * self.current_mas_noise_scale + # ) + # neg_cent = neg_cent + epsilon + + # attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + # attn = ( + # monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)) + # .unsqueeze(1) + # .detach() + # ) + + aln_hard, aln_soft, aln_log, aln_mask = self.aligner( + m_p.transpose(1,2), x_mask, y, y_mask ) + attn = aln_mask.transpose(1,2).unsqueeze(1) + align_loss = self.aligner_loss(aln_log, x_lengths, y_lengths) + if self.aligner_bin_loss_weight > 0.: + align_bin_loss = self.bin_loss(aln_mask, aln_log, x_lengths) * self.aligner_bin_loss_weight + align_loss = align_loss + align_bin_loss w = attn.sum(2) if self.use_sdp: @@ -1307,6 +1327,7 @@ def forward(self, x, x_lengths, y, y_lengths, sid=None): y_mask, (z, z_p, m_p, logs_p, m_q, logs_q), (x, logw, logw_), + align_loss, ) def infer( diff --git a/train_ms.py b/train_ms.py index 17d4cd3..4876436 100644 --- a/train_ms.py +++ b/train_ms.py @@ -358,6 +358,7 @@ def train_and_evaluate( z_mask, (z, z_p, m_p, logs_p, m_q, logs_q), (hidden_x, logw, logw_), + align_loss, ) = net_g(x, x_lengths, spec, spec_lengths, speakers) if ( @@ -437,7 +438,7 @@ def train_and_evaluate( if net_dur_disc is not None: y_dur_hat_r, y_dur_hat_g = net_dur_disc(hidden_x, x_mask, logw_, logw) with autocast(enabled=False): - loss_dur = torch.sum(l_length.float()) + loss_dur = torch.sum(l_length.float()) + align_loss loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl