Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TEST: Add Aligner Net #82

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 235 additions & 0 deletions aligner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
from typing import Tuple
import numpy as np

import torch
from torch import nn, Tensor
from torch.nn import Module
import torch.nn.functional as F

from einops import rearrange, repeat

from beartype import beartype
from beartype.typing import Optional

def exists(val):
return val is not None

class AlignerNet(Module):
"""alignment model https://arxiv.org/pdf/2108.10447.pdf """
def __init__(
self,
dim_in=80,
dim_hidden=512,
attn_channels=80,
temperature=0.0005,
):
super().__init__()
self.temperature = temperature

self.key_layers = nn.ModuleList([
nn.Conv1d(
dim_hidden,
dim_hidden * 2,
kernel_size=3,
padding=1,
bias=True,
),
nn.ReLU(inplace=True),
nn.Conv1d(dim_hidden * 2, attn_channels, kernel_size=1, padding=0, bias=True)
])

self.query_layers = nn.ModuleList([
nn.Conv1d(
dim_in,
dim_in * 2,
kernel_size=3,
padding=1,
bias=True,
),
nn.ReLU(inplace=True),
nn.Conv1d(dim_in * 2, dim_in, kernel_size=1, padding=0, bias=True),
nn.ReLU(inplace=True),
nn.Conv1d(dim_in, attn_channels, kernel_size=1, padding=0, bias=True)
])

@beartype
def forward(
self,
queries: Tensor,
keys: Tensor,
mask: Optional[Tensor] = None
):
key_out = keys
for layer in self.key_layers:
key_out = layer(key_out)

query_out = queries
for layer in self.query_layers:
query_out = layer(query_out)

key_out = rearrange(key_out, 'b c t -> b t c')
query_out = rearrange(query_out, 'b c t -> b t c')

attn_logp = torch.cdist(query_out, key_out)
attn_logp = rearrange(attn_logp, 'b ... -> b 1 ...')

if exists(mask):
mask = rearrange(mask.bool(), '... c -> ... 1 c')
attn_logp.data.masked_fill_(~mask, -torch.finfo(attn_logp.dtype).max)

attn = attn_logp.softmax(dim = -1)
return attn, attn_logp

def pad_tensor(input, pad, value=0):
pad = [item for sublist in reversed(pad) for item in sublist] # Flatten the tuple
assert len(pad) // 2 == len(input.shape), 'Padding dimensions do not match input dimensions'
return F.pad(input, pad, mode='constant', value=value)

def maximum_path(value, mask, const=None):
device = value.device
dtype = value.dtype
if not exists(const):
const = torch.tensor(float('-inf')).to(device) # Patch for Sphinx complaint
value = value * mask

b, t_x, t_y = value.shape
direction = torch.zeros(value.shape, dtype=torch.int64, device=device)
v = torch.zeros((b, t_x), dtype=torch.float32, device=device)
x_range = torch.arange(t_x, dtype=torch.float32, device=device).view(1, -1)

for j in range(t_y):
v0 = pad_tensor(v, ((0, 0), (1, 0)), value = const)[:, :-1]
v1 = v
max_mask = v1 >= v0
v_max = torch.where(max_mask, v1, v0)
direction[:, :, j] = max_mask

index_mask = x_range <= j
v = torch.where(index_mask.view(1,-1), v_max + value[:, :, j], const)

direction = torch.where(mask.bool(), direction, 1)

path = torch.zeros(value.shape, dtype=torch.float32, device=device)
index = mask[:, :, 0].sum(1).long() - 1
index_range = torch.arange(b, device=device)

for j in reversed(range(t_y)):
path[index_range, index, j] = 1
index = index + direction[index_range, index, j] - 1

path = path * mask.float()
path = path.to(dtype=dtype)
return path

class ForwardSumLoss(Module):
def __init__(
self,
blank_logprob = -1
):
super().__init__()
self.blank_logprob = blank_logprob

self.ctc_loss = torch.nn.CTCLoss(
blank = 0, # check this value
zero_infinity = True
)

def forward(self, attn_logprob, key_lens, query_lens):
device, blank_logprob = attn_logprob.device, self.blank_logprob
max_key_len = attn_logprob.size(-1)

# Reorder input to [query_len, batch_size, key_len]
attn_logprob = rearrange(attn_logprob, 'b 1 c t -> c b t')

# Add blank label
attn_logprob = F.pad(attn_logprob, (1, 0, 0, 0, 0, 0), value = blank_logprob)

# Convert to log probabilities
# Note: Mask out probs beyond key_len
mask_value = -torch.finfo(attn_logprob.dtype).max
attn_logprob.masked_fill_(torch.arange(max_key_len + 1, device=device, dtype=torch.long).view(1, 1, -1) > key_lens.view(1, -1, 1), mask_value)

attn_logprob = attn_logprob.log_softmax(dim = -1)

# Target sequences
target_seqs = torch.arange(1, max_key_len + 1, device=device, dtype=torch.long)
target_seqs = repeat(target_seqs, 'n -> b n', b = key_lens.numel())

# Evaluate CTC loss
cost = self.ctc_loss(attn_logprob, target_seqs, query_lens, key_lens)

return cost

class BinLoss(Module):
def forward(self, attn_hard, attn_logprob, key_lens):
batch, device = attn_logprob.shape[0], attn_logprob.device
max_key_len = attn_logprob.size(-1)

# Reorder input to [query_len, batch_size, key_len]
attn_logprob = rearrange(attn_logprob, 'b 1 c t -> c b t')
attn_hard = rearrange(attn_hard, 'b t c -> c b t')

mask_value = -torch.finfo(attn_logprob.dtype).max

attn_logprob.masked_fill_(torch.arange(max_key_len, device=device, dtype=torch.long).view(1, 1, -1) > key_lens.view(1, -1, 1), mask_value)
attn_logprob = attn_logprob.log_softmax(dim = -1)

return (attn_hard * attn_logprob).sum() / batch

class Aligner(Module):
def __init__(
self,
dim_in,
dim_hidden,
attn_channels=80,
temperature=0.0005
):
super().__init__()
self.dim_in = dim_in
self.dim_hidden = dim_hidden
self.attn_channels = attn_channels
self.temperature = temperature
self.aligner = AlignerNet(
dim_in = self.dim_in,
dim_hidden = self.dim_hidden,
attn_channels = self.attn_channels,
temperature = self.temperature
)

def forward(
self,
x,
x_mask,
y,
y_mask
):
alignment_soft, alignment_logprob = self.aligner(y, rearrange(x, 'b d t -> b t d'), x_mask)

x_mask = rearrange(x_mask, '... i -> ... i 1')
y_mask = rearrange(y_mask, '... j -> ... 1 j')
attn_mask = x_mask * y_mask
attn_mask = rearrange(attn_mask, 'b 1 i j -> b i j')

alignment_soft = rearrange(alignment_soft, 'b 1 c t -> b t c')
alignment_mask = maximum_path(alignment_soft, attn_mask)

alignment_hard = torch.sum(alignment_mask, -1).int()
return alignment_hard, alignment_soft, alignment_logprob, alignment_mask

if __name__ == '__main__':
batch_size = 10
seq_len_y = 200 # length of sequence y
seq_len_x = 35
feature_dim = 80 # feature dimension

x = torch.randn(batch_size, 512, seq_len_x)
x = x.transpose(1,2) #dim-1 is the channels for conv
y = torch.randn(batch_size, seq_len_y, feature_dim)
y = y.transpose(1,2) #dim-1 is the channels for conv

# Create masks
x_mask = torch.ones(batch_size, 1, seq_len_x)
y_mask = torch.ones(batch_size, 1, seq_len_y)

align = Aligner(dim_in = 80, dim_hidden=512, attn_channels=80)
alignment_hard, alignment_soft, alignment_logprob, alignment_mas = align(x, x_mask, y, y_mask)
81 changes: 51 additions & 30 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import modules
import monotonic_align
from commons import get_padding, init_weights
from .aligner import Aligner, ForwardSumLoss, BinLoss

AVAILABLE_FLOW_TYPES = [
"pre_conv",
Expand Down Expand Up @@ -1231,6 +1232,16 @@ def __init__(
self.dp = DurationPredictor(
hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
)

self.aligner = Aligner(
dim_in=80,
dim_hidden=self.enc_gin_channels,
attn_channels=self.enc_gin_channels,
)

self.aligner_loss = ForwardSumLoss()
self.bin_loss = BinLoss()
self.aligner_bin_loss_weight = 0.0

if n_speakers > 1:
self.emb_g = nn.Embedding(n_speakers, gin_channels)
Expand All @@ -1245,37 +1256,46 @@ def forward(self, x, x_lengths, y, y_lengths, sid=None):
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
z_p = self.flow(z, y_mask, g=g)

with torch.no_grad():
# negative cross-entropy
s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
neg_cent1 = torch.sum(
-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
) # [b, 1, t_s]
neg_cent2 = torch.matmul(
-0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
neg_cent3 = torch.matmul(
z_p.transpose(1, 2), (m_p * s_p_sq_r)
) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
neg_cent4 = torch.sum(
-0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
) # [b, 1, t_s]
neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4

if self.use_noise_scaled_mas:
epsilon = (
torch.std(neg_cent)
* torch.randn_like(neg_cent)
* self.current_mas_noise_scale
)
neg_cent = neg_cent + epsilon

attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = (
monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
.unsqueeze(1)
.detach()
# with torch.no_grad():
# # negative cross-entropy
# s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
# neg_cent1 = torch.sum(
# -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
# ) # [b, 1, t_s]
# neg_cent2 = torch.matmul(
# -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
# ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
# neg_cent3 = torch.matmul(
# z_p.transpose(1, 2), (m_p * s_p_sq_r)
# ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
# neg_cent4 = torch.sum(
# -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
# ) # [b, 1, t_s]
# neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4

# if self.use_noise_scaled_mas:
# epsilon = (
# torch.std(neg_cent)
# * torch.randn_like(neg_cent)
# * self.current_mas_noise_scale
# )
# neg_cent = neg_cent + epsilon

# attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
# attn = (
# monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
# .unsqueeze(1)
# .detach()
# )

aln_hard, aln_soft, aln_log, aln_mask = self.aligner(
m_p.transpose(1,2), x_mask, y, y_mask
)
attn = aln_mask.transpose(1,2).unsqueeze(1)
align_loss = self.aligner_loss(aln_log, x_lengths, y_lengths)
if self.aligner_bin_loss_weight > 0.:
align_bin_loss = self.bin_loss(aln_mask, aln_log, x_lengths) * self.aligner_bin_loss_weight
align_loss = align_loss + align_bin_loss

w = attn.sum(2)
if self.use_sdp:
Expand Down Expand Up @@ -1307,6 +1327,7 @@ def forward(self, x, x_lengths, y, y_lengths, sid=None):
y_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
(x, logw, logw_),
align_loss,
)

def infer(
Expand Down
3 changes: 2 additions & 1 deletion train_ms.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def train_and_evaluate(
z_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
(hidden_x, logw, logw_),
align_loss,
) = net_g(x, x_lengths, spec, spec_lengths, speakers)

if (
Expand Down Expand Up @@ -437,7 +438,7 @@ def train_and_evaluate(
if net_dur_disc is not None:
y_dur_hat_r, y_dur_hat_g = net_dur_disc(hidden_x, x_mask, logw_, logw)
with autocast(enabled=False):
loss_dur = torch.sum(l_length.float())
loss_dur = torch.sum(l_length.float()) + align_loss
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl

Expand Down