initial #1

Open · wants to merge 1 commit into base: master
28 changes: 28 additions & 0 deletions audio_processing.py
@@ -0,0 +1,28 @@
import numpy as np
import scipy.signal   # a bare `import scipy` does not expose scipy.signal
import librosa

def normalize(wav):
    # Peak-normalize the waveform to the range [-1, 1]
    return wav / np.max(np.abs(wav))

def trim(wav, threshold=0.01):
    # Trim leading/trailing samples whose magnitude falls below the threshold
    keep = np.where(np.abs(wav) > threshold)[0]
    return wav[keep[0]:keep[-1] + 1]

def wav_to_spec(wav, sr, n_fft, n_hop):
    # STFT magnitude in dB, floored at 1e-8 to avoid log(0)
    _, _, Zxx = scipy.signal.stft(wav, fs=sr, nperseg=n_fft, noverlap=n_fft - n_hop)
    return 20 * np.log10(np.maximum(np.abs(Zxx), 1e-8))

def wav_to_melspec(wav, sr, n_fft, n_hop, n_mels, mel_basis=None):
    # Mel spectrogram in dB; pass a precomputed mel_basis to avoid rebuilding the filterbank
    _, _, Zxx = scipy.signal.stft(wav, fs=sr, nperseg=n_fft, noverlap=n_fft - n_hop)
    if mel_basis is None:
        mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels)
    D = np.matmul(mel_basis, np.abs(Zxx))
    return 20 * np.log10(np.maximum(D, 1e-8))

def spec_to_mel(spectrogram, sr, n_fft, n_mels, mel_basis=None):
    # Convert a dB linear spectrogram back to magnitude, project onto the mel basis, return dB
    S = np.power(10.0, spectrogram * 0.05)
    if mel_basis is None:
        mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels)
    D = np.matmul(mel_basis, S)
    return 20 * np.log10(np.maximum(D, 1e-8))
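
A minimal usage sketch for these helpers, assuming a local example.wav (hypothetical path) and the values from hparams.py:

import librosa
import hparams as hp
from audio_processing import normalize, trim, wav_to_melspec

# Load at the project sample rate, clean up the waveform, then convert to a dB mel spectrogram
wav, _ = librosa.load('example.wav', sr=hp.sr)
wav = trim(normalize(wav))
melspec = wav_to_melspec(wav, hp.sr, hp.n_fft, hp.n_hop, hp.n_mels)
print(melspec.shape)   # (n_mels, n_frames)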
111 changes: 111 additions & 0 deletions data.py
@@ -0,0 +1,111 @@
import os
import random

import numpy as np
import torch
import librosa
from torch.utils.data import Dataset, DataLoader

import hparams as hp
from audio_processing import *

class Spliter:
    """Alternately splits a mel spectrogram along the frequency and time axes."""
    def __init__(self):
        self.split = 0

    def reset(self):
        self.split = 0

    def __call__(self, melspec):
        self.split += 1
        if self.split % 2 == 1:
            # Odd call: split along the frequency axis (even vs. odd rows)
            return melspec[0::2, :], melspec[1::2, :]
        else:
            # Even call: split along the time axis (even vs. odd frames)
            return melspec[:, 0::2], melspec[:, 1::2]
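
# Example: with n_tiers = 6, a (128, T) mel spectrogram is split into even/odd
# frequency rows on the first call, even/odd time frames on the second, and so on,
# until only the coarsest (16, T // 4) piece remains for tier 1.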

def pad_mel(melspecs):
    # Zero-pad a list of (F, T_i) spectrograms to the longest T in the batch
    B, F, T = len(melspecs), melspecs[0].shape[0], max(x.shape[1] for x in melspecs)
    padded_mel = np.zeros((B, F, T))
    for i, mel in enumerate(melspecs):
        padded_mel[i, :, :mel.shape[1]] = mel

    return torch.from_numpy(padded_mel).to(torch.float)


class MelData(Dataset):
    def __init__(self, hp):
        super().__init__()
        self.root_dir = hp.root_dir
        self.n_tiers = hp.n_tiers
        self.sr = hp.sr
        self.n_fft = hp.n_fft
        self.n_mels = hp.n_mels
        self.n_hop = hp.n_hop
        self.n_overlap = hp.n_fft - hp.n_hop
        self.n_bucket = hp.n_bucket

        self.wav_files = list(filter(lambda f: f.endswith('.wav'), os.listdir(self.root_dir)))
        self.split = Spliter()
        self.mel_basis = librosa.filters.mel(sr=hp.sr, n_fft=hp.n_fft, n_mels=hp.n_mels)
        self.bucket_by_sequence_length()

    def bucket_by_sequence_length(self):
        ##### Bucket files by waveform length #####
        self.wav_lengths = []
        for wav_file in self.wav_files:
            wav, _ = librosa.load(os.path.join(self.root_dir, wav_file), sr=self.sr)
            self.wav_lengths.append(len(wav))

        self.wav_length = list(zip(self.wav_files, self.wav_lengths))
        self.wav_length.sort(key=lambda x: x[1])

        # Group the sorted files into n_bucket buckets of similar length
        self.buckets = {}
        bucket_size = len(self.wav_length) // self.n_bucket
        for i in range(self.n_bucket):
            self.buckets[i] = self.wav_length[i*bucket_size : (i+1)*bucket_size]

def shuffle(self):
for i in range(self.n_bucket):
random.shuffle(self.buckets[i])

    def __getitem__(self, i):
        ##### Wav -> Melspectrogram #####
        # Load with librosa so the waveform is resampled to self.sr, matching the mel basis
        wav, _ = librosa.load(os.path.join(self.root_dir, self.wav_files[i]), sr=self.sr)
        wav = normalize(wav)
        wav = trim(wav)
        melspec = wav_to_melspec(wav, self.sr, self.n_fft, self.n_hop, self.n_mels, self.mel_basis)

        ##### Melspectrogram Validation #####
        # Drop trailing frames so the time axis is divisible by 2**n_half_t
        n_half_t = (self.n_tiers - 1) // 2
        n_time = melspec.shape[1] - melspec.shape[1] % 2**n_half_t
        melspec = melspec[:, :n_time]

        ##### Build mel_tiers #####
        # Successively halve the spectrogram: mel_tiers[t] gets the half split off
        # at step t, and mel_tiers[1] keeps the final, coarsest residue
        mel_tiers = [None] + [0 for _ in range(self.n_tiers)]
        for t in range(self.n_tiers, 1, -1):
            tier, melspec = self.split(melspec)
            mel_tiers[t] = tier
        mel_tiers[1] = melspec
        self.split.reset()

        return mel_tiers

def __len__(self):
return len(self.wav_files)



class MelCollate():
def __init__(self, hp):
self.n_tiers = hp.n_tiers

    def __call__(self, batch):
        # Collect each tier across the batch, then zero-pad every tier to a common length
        mel_tiers = [None] + [ [] for _ in range(self.n_tiers) ]

for data in batch:
for t in range(1, self.n_tiers+1):
mel_tiers[t].append(data[t])

for t in range(1, self.n_tiers+1):
mel_tiers[t] = pad_mel(mel_tiers[t])

return mel_tiers
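
A sketch of how MelData and MelCollate might be wired together, assuming hp.root_dir points at a directory of .wav files:

from torch.utils.data import DataLoader
import hparams as hp
from data import MelData, MelCollate

dataset = MelData(hp)
loader = DataLoader(dataset, batch_size=hp.batch_size, shuffle=True,
                    collate_fn=MelCollate(hp))

for mel_tiers in loader:
    # mel_tiers[1] .. mel_tiers[hp.n_tiers] are padded (B, F_t, T_t) float tensors
    print(mel_tiers[1].shape)
    break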
13 changes: 13 additions & 0 deletions hparams.py
@@ -0,0 +1,13 @@
root_dir = '../dataset/KSS/wavs'
n_tiers = 6
sr = 22050
n_fft = 6 * 256      # 1536-sample analysis window
n_hop = 256
n_mels = 128


batch_size = 8
hidden_dim = 512
epochs = 5
use_central = True
n_bucket = 10
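
For reference, a quick sanity check of the tier-1 sizes these values imply (a scratch snippet; the variable names are illustrative):

from math import ceil
import hparams as hp

# Frequency bins left in the coarsest tier after ceil((n_tiers - 1) / 2) frequency splits
tier1_mels = hp.n_mels // 2 ** ceil((hp.n_tiers - 1) / 2)   # 128 // 8 = 16
frames_per_second = hp.sr / hp.n_hop                        # ~86 mel frames per second
print(tier1_mels, frames_per_second)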
59 changes: 59 additions & 0 deletions model.py
@@ -0,0 +1,59 @@
import torch
import torch.nn as nn
from math import ceil
from module import Stack, UpperStack

class MelNet(nn.Module):
def __init__(self, hp):
super(MelNet, self).__init__()
self.first_stack = Stack(hp, use_central=True)
self.rest_stack = nn.ModuleList([Stack(hp, use_central=False) for _ in range(3)])
self.upper_stack = nn.ModuleList([UpperStack(hp) for _ in range(3)])

#self.pred_layer = nn.Sequential( nn.Linear(hp.hidden_dim, 1), nn.Sigmoid() )
self.pred_layer = nn.Linear(hp.hidden_dim, 1)

self.time_expand = nn.Linear(1, hp.hidden_dim)
self.freq_expand = nn.Linear(1, hp.hidden_dim)

        self.use_central = hp.use_central
        if self.use_central:
            # Tier 1 keeps n_mels / 2**ceil((n_tiers - 1) / 2) frequency bins after the dataset splits
            self.central_expand = nn.Linear(int(hp.n_mels * 0.5 ** ceil((hp.n_tiers - 1) / 2)), hp.hidden_dim)

    def Tier1(self, x):
        B, F, T = x.size()
        # Shift the inputs by one step (time / frequency) so each position is
        # conditioned only on previously generated values, not on itself
        GO_time, GO_freq = x.new_zeros(B, F, 1), x.new_zeros(B, 1, T)
        x_time = torch.cat([GO_time, x[:, :, :-1]], dim=2)
        x_freq = torch.cat([GO_freq, x[:, :-1, :]], dim=1)
        x_t = self.time_expand(x_time.unsqueeze(-1))
        x_f = self.freq_expand(x_freq.unsqueeze(-1))

        if self.use_central:
            # The centralized stack also conditions on the time-shifted frames
            x_c = self.central_expand(x_time.transpose(1, 2))
        else:
            x_c = None

time_out, freq_out = self.first_stack( x_t, x_f, x_c )
for stack in self.rest_stack:
time_out, freq_out = stack(time_out, freq_out)

out = self.pred_layer(freq_out).squeeze(-1)

return out

    def not_Tier1(self, x, target=2):
        # TODO: upsampling tiers conditioned on the interleaved lower tiers
        raise NotImplementedError

    def interleave(self, tier_n, tier_m):
        # TODO: merge two tiers back along the axis they were split on
        raise NotImplementedError

def forward(self, x):
tier1 = self.Tier1(x)
tier2 = self.not_Tier1(tier1)
tier3 = self.not_Tier1(self.interleave(tier1, tier2), target=3)
tier4 = self.not_Tier1(self.interleave(tier2, tier3), target=4)
tier5 = self.not_Tier1(self.interleave(tier3, tier4), target=5)
tier6 = self.not_Tier1(self.interleave(tier4, tier5), target=6)
return tier1, tier2, tier3, tier4, tier5, tier6

def infer(self):
raise NotImplementedError
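
A rough smoke test for the tier-1 path; random data stands in for the coarsest tier produced by MelData, and only Tier1 is exercised since the upper tiers are not implemented yet:

import torch
from math import ceil
import hparams as hp
from model import MelNet

model = MelNet(hp)

# Coarsest tier: 128 / 2**3 = 16 mel bins by an arbitrary number of frames
B, F, T = 2, hp.n_mels // 2 ** ceil((hp.n_tiers - 1) / 2), 40
x = torch.randn(B, F, T)

out = model.Tier1(x)     # (B, F, T), one prediction per bin
print(out.shape)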
96 changes: 96 additions & 0 deletions module.py
@@ -0,0 +1,96 @@
import torch
import torch.nn as nn
from math import ceil

class Stack(nn.Module):
    def __init__(self, hp, use_central=False):
        super(Stack, self).__init__()
        # Time-delayed stack: one RNN along time plus two along frequency (up / down)
        self.time_rnn1 = nn.GRU(hp.hidden_dim, hp.hidden_dim, batch_first=True)
        self.time_rnn2 = nn.GRU(hp.hidden_dim, hp.hidden_dim, batch_first=True)
        self.time_rnn3 = nn.GRU(hp.hidden_dim, hp.hidden_dim, batch_first=True)
        # Frequency-delayed stack
        self.freq_rnn = nn.GRU(hp.hidden_dim, hp.hidden_dim, batch_first=True)

        self.time_linear = nn.Linear(hp.hidden_dim * 3, hp.hidden_dim)
        self.freq_linear = nn.Linear(hp.hidden_dim, hp.hidden_dim)

        self.use_central = use_central
        if use_central:
            self.central_rnn = nn.GRU(hp.hidden_dim, hp.hidden_dim, batch_first=True)
            self.central_linear = nn.Linear(hp.hidden_dim, hp.hidden_dim)



    def time_stack(self, x_t):
        # x_t: (B, F, T, hidden)
        B, F, T = x_t.size()[:-1]

        # Run along the time axis for each frequency bin
        direction_right = []
        for i in range(F):
            direction_right.append(self.time_rnn1(x_t[:, i])[0].unsqueeze(1))
        direction_right = torch.cat(direction_right, dim=1)

        # Run up the frequency axis for each frame
        direction_up = []
        for i in range(T):
            direction_up.append(self.time_rnn2(x_t[:, :, i])[0].unsqueeze(2))
        direction_up = torch.cat(direction_up, dim=2)

        # Run down the frequency axis: flip, apply the RNN, flip back so frames stay aligned
        direction_down = []
        for i in range(T):
            flipped = torch.flip(x_t[:, :, i], dims=[1])
            direction_down.append(torch.flip(self.time_rnn3(flipped)[0], dims=[1]).unsqueeze(2))
        direction_down = torch.cat(direction_down, dim=2)

        out = torch.cat([direction_right, direction_up, direction_down], dim=-1)

        out = self.time_linear(out) + x_t

        return out


def freq_stack(self, x_t, x_f, x_c=None):
B, F, T = x_t.size()[:-1]

        x_input = x_t + x_f
        if x_c is not None:
            # x_c is (B, T, hidden); broadcast it across the frequency axis
            x_input = x_input + x_c.unsqueeze(1)

direction_up = []
for i in range( T ):
direction_up.append(self.freq_rnn(x_input[:,:,i])[0].unsqueeze(2))
direction_up = torch.cat(direction_up, dim=2)

out = self.freq_linear(direction_up) + x_f

return out


    def cent_stack(self, x_c):
        # x_c: (B, T, hidden); note that forward() does not call this yet
        B, T = x_c.size()[:-1]
        out = self.central_rnn(x_c)[0]
        out = self.central_linear(out)
        out = out + x_c
        return out


def forward(self, x_t, x_f, x_c=None):
time_out = self.time_stack(x_t)
freq_out = self.freq_stack(time_out, x_f, x_c)

return time_out, freq_out



class UpperStack(nn.Module):
    def __init__(self, hp):
        super(UpperStack, self).__init__()
        self.time_rnn = nn.GRU(hp.hidden_dim, hp.hidden_dim, bidirectional=True, batch_first=True)
        self.freq_rnn = nn.GRU(hp.hidden_dim, hp.hidden_dim, bidirectional=True, batch_first=True)

        # Each bidirectional GRU returns 2*hidden_dim features, hence 4*hidden_dim after concatenation
        self.linear = nn.Linear(hp.hidden_dim*4, 1)

    def forward(self, x):
        # x: (B, F, T, hidden). Run the bidirectional RNNs along each axis
        # per slice, following the same pattern as Stack above.
        B, F, T = x.size()[:-1]

        time_out = []
        for i in range(F):
            time_out.append(self.time_rnn(x[:, i])[0].unsqueeze(1))
        time_out = torch.cat(time_out, dim=1)

        freq_out = []
        for i in range(T):
            freq_out.append(self.freq_rnn(x[:, :, i])[0].unsqueeze(2))
        freq_out = torch.cat(freq_out, dim=2)

        out = torch.cat([time_out, freq_out], dim=-1)
        out = self.linear(out)

        return out
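
A quick shape check for Stack on its own, with small illustrative sizes:

import torch
import hparams as hp
from module import Stack

stack = Stack(hp, use_central=True)

B, F, T = 2, 16, 10
x_t = torch.randn(B, F, T, hp.hidden_dim)   # time-delayed features
x_f = torch.randn(B, F, T, hp.hidden_dim)   # frequency-delayed features
x_c = torch.randn(B, T, hp.hidden_dim)      # central features, one vector per frame

time_out, freq_out = stack(x_t, x_f, x_c)
print(time_out.shape, freq_out.shape)       # both (B, F, T, hidden_dim)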