diff --git a/audio_processing.py b/audio_processing.py new file mode 100644 index 0000000..86aa124 --- /dev/null +++ b/audio_processing.py @@ -0,0 +1,28 @@ +import numpy as np +import scipy.signal, librosa + +def normalize(wav): + return wav / np.max( np.abs(wav) ) + +def trim(wav, threshold=0.01): + cut = np.where((abs(wav)>threshold))[0] + wav = wav[cut[0]:(cut[-1]+1)] # Trim leading/trailing silence + return wav + +def wav_to_spec(wav, sr, n_fft, n_hop): + _, _, Zxx = scipy.signal.stft(wav, fs=sr, nperseg=n_fft, noverlap=n_fft-n_hop) + return 20 * np.log10(np.maximum(np.abs(Zxx), 1e-8)) + +def wav_to_melspec(wav, sr, n_fft, n_hop, n_mels, mel_basis=None): + _, _, Zxx = scipy.signal.stft(wav, fs=sr, nperseg=n_fft, noverlap=n_fft-n_hop) + if mel_basis is None: + mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels) + D = np.matmul(mel_basis, np.abs(Zxx)) + return 20 * np.log10(np.maximum(D, 1e-8)) + +def spec_to_mel(spectrogram, sr, n_fft, n_mels, mel_basis=None): + S = np.power(10.0, spectrogram * 0.05) # dB -> linear magnitude + if mel_basis is None: + mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels) + D = np.matmul(mel_basis, S) + return 20 * np.log10(np.maximum(D, 1e-8)) \ No newline at end of file diff --git a/data.py b/data.py new file mode 100644 index 0000000..b319e0c --- /dev/null +++ b/data.py @@ -0,0 +1,111 @@ +from torch.utils.data import Dataset, DataLoader +import scipy.io.wavfile, librosa +from audio_processing import * +import os +import torch +import random +import hparams as hp + +class Spliter(): # alternately splits a mel spectrogram along the frequency axis, then the time axis + def __init__(self): + self.split=0 + + def reset(self): + self.split=0 + + def __call__(self, melspec): + if self.split%2==0: + self.split += 1 + return melspec[0::2, :], melspec[1::2, :] + + elif self.split%2==1: + self.split += 1 + return melspec[:, 0::2], melspec[:, 1::2] + +def pad_mel(melspecs): + B, F, T = len(melspecs), melspecs[0].shape[0], max([x.shape[1] for x in melspecs]) + padded_mel = np.zeros((B, F, T)) + for i, mel in enumerate(melspecs): + padded_mel[i, :, :mel.shape[1]] = mel + + return torch.from_numpy(padded_mel).to(torch.float) + + +class MelData(Dataset): + def __init__(self, hp): + super(MelData, self).__init__() + self.root_dir = hp.root_dir + self.n_tiers = hp.n_tiers + self.sr = hp.sr + self.n_fft = hp.n_fft + self.n_mels = hp.n_mels + self.n_hop = hp.n_hop + self.n_overlap = hp.n_fft - hp.n_hop + self.n_bucket = hp.n_bucket + + self.wav_files = list(filter(lambda f: f.endswith('.wav'), os.listdir(self.root_dir))) + self.split = Spliter() + self.mel_basis = librosa.filters.mel(sr=hp.sr, n_fft=hp.n_fft, n_mels=hp.n_mels) + self.bucket_by_sequence_length() + + def bucket_by_sequence_length(self): + ##### Group wav files into buckets by length ##### + self.wav_lengths = [] + for wav_file in self.wav_files: + wav, _ = librosa.load( os.path.join(self.root_dir, wav_file) ) + self.wav_lengths.append(len(wav)) + + self.wav_length = list(zip(self.wav_files, self.wav_lengths)) + self.wav_length.sort(key = lambda x: x[1]) + + self.buckets = {} + bucket_size = len(self.wav_length)//self.n_bucket + for i in range(self.n_bucket): + self.buckets[i] = self.wav_length[i*bucket_size : (i+1)*bucket_size] + + def shuffle(self): + for i in range(self.n_bucket): + random.shuffle(self.buckets[i]) + + def __getitem__(self, i): + ##### Wav -> Melspectrogram ##### + _, wav = scipy.io.wavfile.read( os.path.join(self.root_dir, self.wav_files[i]) ) + wav = normalize(wav) + wav = trim(wav) + melspec = wav_to_melspec(wav, self.sr, self.n_fft, self.n_hop, self.n_mels, self.mel_basis) + + ##### Crop the time axis so it splits evenly across tiers ##### + n_half_t = (self.n_tiers-1) // 2 + n_time = melspec.shape[1] - melspec.shape[1] % 2**n_half_t + melspec = melspec[:, :n_time] + + ##### Build mel_tiers ##### + mel_tiers = [None] + [ 0 for _ in range(self.n_tiers) ] + for t in range(self.n_tiers, 1, -1): + tier, melspec = self.split(melspec) + mel_tiers[t] = tier + mel_tiers[1] = melspec + self.split.reset() + + return mel_tiers + + def __len__(self): + return len(self.wav_files) + + + +class MelCollate(): + def __init__(self, hp): + self.n_tiers = hp.n_tiers + + def __call__(self, batch): + mel_tiers = [None] + [ [] for _ in range(self.n_tiers) ] + + for data in batch: + for t in range(1, self.n_tiers+1): + mel_tiers[t].append(data[t]) + + for t in range(1, self.n_tiers+1): + mel_tiers[t] = pad_mel(mel_tiers[t]) + + return mel_tiers \ No newline at end of file diff --git a/hparams.py b/hparams.py new file mode 100644 index 0000000..b05af8a --- /dev/null +++ b/hparams.py @@ -0,0 +1,13 @@ +root_dir = '../dataset/KSS/wavs' +n_tiers = 6 +sr = 22050 +n_fft = 6 * 256 +n_hop = 256 +n_mels = 128 + + +batch_size = 8 +hidden_dim = 512 +epochs=5 +use_central=True +n_bucket=10 \ No newline at end of file diff --git a/model.py b/model.py new file mode 100644 index 0000000..b0df808 --- /dev/null +++ b/model.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn +from math import ceil +from module import Stack, UpperStack + +class MelNet(nn.Module): + def __init__(self, hp): + super(MelNet, self).__init__() + self.first_stack = Stack(hp, use_central=True) + self.rest_stack = nn.ModuleList([Stack(hp, use_central=False) for _ in range(3)]) + self.upper_stack = nn.ModuleList([UpperStack(hp) for _ in range(3)]) + + #self.pred_layer = nn.Sequential( nn.Linear(hp.hidden_dim, 1), nn.Sigmoid() ) + self.pred_layer = nn.Linear(hp.hidden_dim, 1) + + self.time_expand = nn.Linear(1, hp.hidden_dim) + self.freq_expand = nn.Linear(1, hp.hidden_dim) + + self.use_central = hp.use_central + if self.use_central==True: + self.central_expand = nn.Linear( int(hp.n_mels*(0.5)**(ceil((hp.n_tiers-1)/2))) , hp.hidden_dim) + + def Tier1(self, x): # x: (B, F, T) mel spectrogram of the coarsest tier + B,F,T = x.size() + GO_time, GO_freq = x.new_zeros(B,F,1), x.new_zeros(B,1,T) + x_t, x_f = torch.cat([GO_time, x[:,:,1:]], dim=2).unsqueeze(-1), torch.cat([GO_freq, x[:,1:,:]], dim=1).unsqueeze(-1) + x_t = self.time_expand(x_t) + x_f = self.freq_expand(x_f) + + if self.use_central==True: + x_c = self.central_expand(x.transpose(1,2)) + else: + x_c = None + + time_out, freq_out = self.first_stack( x_t, x_f, x_c ) + for stack in self.rest_stack: + time_out, freq_out = stack(time_out, freq_out) + + out = self.pred_layer(freq_out).squeeze(-1) + + return out + + def not_Tier1(self, x, target=2): + raise NotImplementedError + + def interleave(self, tier_n, tier_m): + raise NotImplementedError + + def forward(self, x): + tier1 = self.Tier1(x) + tier2 = self.not_Tier1(tier1) + tier3 = self.not_Tier1(self.interleave(tier1, tier2), target=3) + tier4 = self.not_Tier1(self.interleave(tier2, tier3), target=4) + tier5 = self.not_Tier1(self.interleave(tier3, tier4), target=5) + tier6 = self.not_Tier1(self.interleave(tier4, tier5), target=6) + return tier1, tier2, tier3, tier4, tier5, tier6 + + def infer(self): + raise NotImplementedError \ No newline at end of file diff --git a/module.py b/module.py new file mode 100644 index 0000000..d80183f --- /dev/null +++ b/module.py @@ -0,0 +1,96 @@ +import torch +import torch.nn as nn +from math import ceil + +class Stack(nn.Module): + def __init__(self, hp, use_central=False): + super(Stack, self).__init__() + self.time_rnn1 = nn.GRU(hp.hidden_dim, hp.hidden_dim,
batch_first=True) + self.time_rnn2 = nn.GRU(hp.hidden_dim, hp.hidden_dim, batch_first=True) + self.time_rnn3 = nn.GRU(hp.hidden_dim, hp.hidden_dim, batch_first=True) + self.freq_rnn = nn.GRU(hp.hidden_dim, hp.hidden_dim, batch_first=True) + + self.time_linear = nn.Linear(hp.hidden_dim*3, hp.hidden_dim) + self.freq_linear = nn.Linear(hp.hidden_dim, hp.hidden_dim) + + self.use_central=use_central + if use_central==True: + self.central_rnn = nn.GRU(hp.hidden_dim, hp.hidden_dim, batch_first=True) + self.central_linear = nn.Linear(hp.hidden_dim, hp.hidden_dim) + + + + def time_stack(self, x_t): + B, F, T = x_t.size()[:-1] + + direction_right = [] # forward-in-time RNN, one pass per frequency row + for i in range( F ): + direction_right.append(self.time_rnn1(x_t[:,i])[0].unsqueeze(1)) + direction_right = torch.cat(direction_right, dim=1) + + direction_up = [] # upward-in-frequency RNN, one pass per time step + for i in range( T ): + direction_up.append(self.time_rnn2(x_t[:,:,i])[0].unsqueeze(2)) + direction_up = torch.cat(direction_up, dim=2) + + direction_down = [] # downward-in-frequency RNN: reverse the mel axis, then restore its order + for i in range( T ): + direction_down.append(self.time_rnn3(x_t[:,:,i].flip(1))[0].flip(1).unsqueeze(2)) + direction_down = torch.cat(direction_down, dim=2) + + out = torch.cat([direction_right, direction_up, direction_down], dim=-1) + + out = self.time_linear(out) + x_t + + return out + + + def freq_stack(self, x_t, x_f, x_c=None): + B, F, T = x_t.size()[:-1] + + x_input = x_t + x_f + if x_c is not None: + x_input = x_input + x_c.unsqueeze(1) + + direction_up = [] + for i in range( T ): + direction_up.append(self.freq_rnn(x_input[:,:,i])[0].unsqueeze(2)) + direction_up = torch.cat(direction_up, dim=2) + + out = self.freq_linear(direction_up) + x_f + + return out + + + def cent_stack(self, x_c): + B, T = x_c.size()[:-1] + out = self.central_rnn(x_c)[0] + out = self.central_linear(out) + out = out + x_c + return out + + + def forward(self, x_t, x_f, x_c=None): + time_out = self.time_stack(x_t) + freq_out = self.freq_stack(time_out, x_f, x_c) + + return time_out, freq_out + + + +class UpperStack(nn.Module): + def __init__(self, hp): + super(UpperStack, self).__init__() + self.time_rnn = nn.GRU(hp.hidden_dim, hp.hidden_dim, bidirectional=True, batch_first=True) + self.freq_rnn = nn.GRU(hp.hidden_dim, hp.hidden_dim, bidirectional=True, batch_first=True) + + self.linear = nn.Linear(hp.hidden_dim*4, 1) + + def forward(self, x): + time_out = self.time_rnn(x.transpose(1,2).contiguous())[0] + freq_out = self.freq_rnn(x)[0] + + out = torch.cat([time_out, freq_out], dim=-1) + out = self.linear(out) + + return out \ No newline at end of file diff --git a/train.ipynb b/train.ipynb new file mode 100644 index 0000000..047a674 --- /dev/null +++ b/train.ipynb @@ -0,0 +1,108 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import hparams as hp\n", + "from model import MelNet\n", + "from torch.utils.data import DataLoader\n", + "from data import MelData, MelCollate\n", + "import matplotlib.pyplot as plt\n", + "from utils import *\n", + "import numpy as np\n", + "%matplotlib inline\n", + "np.set_printoptions(precision=3, suppress=True)\n", + "dataset = MelData(hp)\n", + "collate_fn = MelCollate(hp)\n", + "dataloader = DataLoader(dataset, batch_size = hp.batch_size, shuffle=True, drop_last=True, collate_fn = collate_fn)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "device = torch.device('cuda:1')\n", + "model = MelNet(hp)\n", + "model.to(device)\n", +
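"# NOTE: the optimizer lr set below is overridden every iteration by lr_scheduling() in the training loop\n", +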
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)\n", + "criterion = nn.L1Loss()\n", + "\n", + "print(\"Training Start!!!\")\n", + "iteration=1\n", + "for epoch in range(hp.epochs):\n", + " for tiers in dataloader:\n", + " lr_scheduling(optimizer, iteration, 3e-4, 10)\n", + " \n", + " model.zero_grad()\n", + " target = tiers[1].to(device)\n", + " y_pred = model(target)\n", + "\n", + " loss = criterion(y_pred, target)\n", + " loss.backward()\n", + " optimizer.step()\n", + " \n", + " iteration += 1\n", + " \n", + " if iteration%20==0:\n", + " fig = plt.figure(figsize=(20,10))\n", + " ax1, ax2, ax3 = fig.add_subplot(221), fig.add_subplot(222), fig.add_subplot(223)\n", + " ax1.imshow(target[0].cpu().numpy(), origin='lower')\n", + " ax2.imshow(y_pred[0].cpu().detach().numpy(), origin='lower')\n", + " ax3.imshow(np.tile(y_pred[0][0:1].cpu().detach().numpy(), (16,1)), origin='lower')\n", + " plt.show()\n", + " print(y_pred[0][0].cpu().detach().numpy()[:10])\n", + " print(y_pred[0][1].cpu().detach().numpy()[:10])\n", + " print(loss.item())\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow( y_pred[0].cpu().detach().numpy(), origin='lower')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "y_pred[0][2]" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:LYH] *", + "language": "python", + "name": "conda-env-LYH-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..18d3e82 --- /dev/null +++ b/utils.py @@ -0,0 +1,211 @@ +import math +import torch +from torch.optim.optimizer import Optimizer, required + +def lr_scheduling(opt, step, init_lr=1e-3, warmup_steps=4000): + opt.param_groups[0]['lr'] = init_lr * warmup_steps ** 0.5 * min(step * warmup_steps ** -1.5, step ** -0.5) + return + +class RAdam(Optimizer): + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + self.buffer = [[None, None, None] for ind in range(10)] + super(RAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RAdam, self).__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('RAdam does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state['step'] += 1 + buffered = self.buffer[int(state['step'] % 10)] + if state['step'] == buffered[0]: + N_sma, 
step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) + else: + step_size = group['lr'] / (1 - beta1 ** state['step']) + buffered[2] = step_size + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + + # more conservative since it's an approximated value + if N_sma >= 5: + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size, exp_avg, denom) + else: + p_data_fp32.add_(-step_size, exp_avg) + + p.data.copy_(p_data_fp32) + + return loss + +class PlainRAdam(Optimizer): + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + + super(PlainRAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(PlainRAdam, self).__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('RAdam does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state['step'] += 1 + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size, exp_avg, denom) + else: + step_size = group['lr'] / (1 - beta1 ** state['step']) + p_data_fp32.add_(-step_size, exp_avg) + + p.data.copy_(p_data_fp32) + + return loss + + +class AdamW(Optimizer): + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0): + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, warmup = warmup) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + + 
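# the update is computed on an fp32 copy of the parameters (p_data_fp32) and written back via p.data.copy_() at the end of step() +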
p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + denom = exp_avg_sq.sqrt().add_(group['eps']) + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + if group['warmup'] > state['step']: + scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] + else: + scheduled_lr = group['lr'] + + step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 # use the warmed-up lr so warmup actually scales the update + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) + + p_data_fp32.addcdiv_(-step_size, exp_avg, denom) + + p.data.copy_(p_data_fp32) + + return loss \ No newline at end of file diff --git a/wav_length_pair.pkl b/wav_length_pair.pkl new file mode 100644 index 0000000..2ed15dd Binary files /dev/null and b/wav_length_pair.pkl differ
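A minimal usage sketch (not part of the diff above), assuming hparams.root_dir points at an existing folder of .wav files; only Tier1 of MelNet is exercised here, since not_Tier1() and interleave() are still NotImplementedError stubs.

import torch
from torch.utils.data import DataLoader

import hparams as hp
from data import MelData, MelCollate
from model import MelNet
from utils import lr_scheduling

dataset = MelData(hp)                      # loads wavs and builds the per-utterance mel tiers
loader = DataLoader(dataset, batch_size=hp.batch_size, shuffle=True,
                    drop_last=True, collate_fn=MelCollate(hp))

model = MelNet(hp)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.L1Loss()

tiers = next(iter(loader))                 # tiers[1] has shape (B, n_mels // 8, T // 4) for n_tiers = 6
pred = model.Tier1(tiers[1])               # (B, F, T) prediction for the coarsest tier only
loss = criterion(pred, tiers[1])
loss.backward()
lr_scheduling(optimizer, step=1, init_lr=3e-4, warmup_steps=10)
optimizer.step()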