-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
958dab2
commit 68b90d6
Showing
60 changed files
with
11,042 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# import audio.tools | ||
# import audio.stft | ||
# import audio.audio_processing | ||
from .stft import * | ||
from .audio_processing import * | ||
from .tools import * |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import torch | ||
import numpy as np | ||
import librosa.util as librosa_util | ||
from scipy.signal import get_window | ||
|
||
|
||
def window_sumsquare( | ||
window, | ||
n_frames, | ||
hop_length, | ||
win_length, | ||
n_fft, | ||
dtype=np.float32, | ||
norm=None, | ||
): | ||
""" | ||
# from librosa 0.6 | ||
Compute the sum-square envelope of a window function at a given hop length. | ||
This is used to estimate modulation effects induced by windowing | ||
observations in short-time fourier transforms. | ||
Parameters | ||
---------- | ||
window : string, tuple, number, callable, or list-like | ||
Window specification, as in `get_window` | ||
n_frames : int > 0 | ||
The number of analysis frames | ||
hop_length : int > 0 | ||
The number of samples to advance between frames | ||
win_length : [optional] | ||
The length of the window function. By default, this matches `n_fft`. | ||
n_fft : int > 0 | ||
The length of each analysis frame. | ||
dtype : np.dtype | ||
The data type of the output | ||
Returns | ||
------- | ||
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` | ||
The sum-squared envelope of the window function | ||
""" | ||
if win_length is None: | ||
win_length = n_fft | ||
|
||
n = n_fft + hop_length * (n_frames - 1) | ||
x = np.zeros(n, dtype=dtype) | ||
|
||
# Compute the squared window at the desired length | ||
win_sq = get_window(window, win_length, fftbins=True) | ||
win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 | ||
win_sq = librosa_util.pad_center(win_sq, n_fft) | ||
|
||
# Fill the envelope | ||
for i in range(n_frames): | ||
sample = i * hop_length | ||
x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] | ||
return x | ||
|
||
|
||
def griffin_lim(magnitudes, stft_fn, n_iters=30): | ||
""" | ||
PARAMS | ||
------ | ||
magnitudes: spectrogram magnitudes | ||
stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods | ||
""" | ||
|
||
angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) | ||
angles = angles.astype(np.float32) | ||
angles = torch.autograd.Variable(torch.from_numpy(angles)) | ||
signal = stft_fn.inverse(magnitudes, angles).squeeze(1) | ||
|
||
for i in range(n_iters): | ||
_, angles = stft_fn.transform(signal) | ||
signal = stft_fn.inverse(magnitudes, angles).squeeze(1) | ||
return signal | ||
|
||
|
||
def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5): | ||
""" | ||
PARAMS | ||
------ | ||
C: compression factor | ||
""" | ||
return normalize_fun(torch.clamp(x, min=clip_val) * C) | ||
|
||
|
||
def dynamic_range_decompression(x, C=1): | ||
""" | ||
PARAMS | ||
------ | ||
C: compression factor used to compress | ||
""" | ||
return torch.exp(x) / C |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
import torch | ||
import torch.nn.functional as F | ||
import numpy as np | ||
from scipy.signal import get_window | ||
from librosa.util import pad_center, tiny | ||
from librosa.filters import mel as librosa_mel_fn | ||
|
||
from audioldm_eval.audio.audio_processing import ( | ||
dynamic_range_compression, | ||
dynamic_range_decompression, | ||
window_sumsquare, | ||
) | ||
|
||
|
||
class STFT(torch.nn.Module): | ||
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" | ||
|
||
def __init__(self, filter_length, hop_length, win_length, window="hann"): | ||
super(STFT, self).__init__() | ||
self.filter_length = filter_length | ||
self.hop_length = hop_length | ||
self.win_length = win_length | ||
self.window = window | ||
self.forward_transform = None | ||
scale = self.filter_length / self.hop_length | ||
fourier_basis = np.fft.fft(np.eye(self.filter_length)) | ||
|
||
cutoff = int((self.filter_length / 2 + 1)) | ||
fourier_basis = np.vstack( | ||
[np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] | ||
) | ||
|
||
forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) | ||
inverse_basis = torch.FloatTensor( | ||
np.linalg.pinv(scale * fourier_basis).T[:, None, :] | ||
) | ||
|
||
if window is not None: | ||
assert filter_length >= win_length | ||
# get window and zero center pad it to filter_length | ||
fft_window = get_window(window, win_length, fftbins=True) | ||
fft_window = pad_center(fft_window, filter_length) | ||
fft_window = torch.from_numpy(fft_window).float() | ||
|
||
# window the bases | ||
forward_basis *= fft_window | ||
inverse_basis *= fft_window | ||
|
||
self.register_buffer("forward_basis", forward_basis.float()) | ||
self.register_buffer("inverse_basis", inverse_basis.float()) | ||
|
||
def transform(self, input_data): | ||
num_batches = input_data.size(0) | ||
num_samples = input_data.size(1) | ||
|
||
self.num_samples = num_samples | ||
|
||
# similar to librosa, reflect-pad the input | ||
input_data = input_data.view(num_batches, 1, num_samples) | ||
input_data = F.pad( | ||
input_data.unsqueeze(1), | ||
(int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), | ||
mode="reflect", | ||
) | ||
input_data = input_data.squeeze(1) | ||
|
||
forward_transform = F.conv1d( | ||
input_data, | ||
torch.autograd.Variable(self.forward_basis, requires_grad=False), | ||
stride=self.hop_length, | ||
padding=0, | ||
).cpu() | ||
|
||
cutoff = int((self.filter_length / 2) + 1) | ||
real_part = forward_transform[:, :cutoff, :] | ||
imag_part = forward_transform[:, cutoff:, :] | ||
|
||
magnitude = torch.sqrt(real_part**2 + imag_part**2) | ||
phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) | ||
|
||
return magnitude, phase | ||
|
||
def inverse(self, magnitude, phase): | ||
recombine_magnitude_phase = torch.cat( | ||
[magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 | ||
) | ||
|
||
inverse_transform = F.conv_transpose1d( | ||
recombine_magnitude_phase, | ||
torch.autograd.Variable(self.inverse_basis, requires_grad=False), | ||
stride=self.hop_length, | ||
padding=0, | ||
) | ||
|
||
if self.window is not None: | ||
window_sum = window_sumsquare( | ||
self.window, | ||
magnitude.size(-1), | ||
hop_length=self.hop_length, | ||
win_length=self.win_length, | ||
n_fft=self.filter_length, | ||
dtype=np.float32, | ||
) | ||
# remove modulation effects | ||
approx_nonzero_indices = torch.from_numpy( | ||
np.where(window_sum > tiny(window_sum))[0] | ||
) | ||
window_sum = torch.autograd.Variable( | ||
torch.from_numpy(window_sum), requires_grad=False | ||
) | ||
window_sum = window_sum | ||
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ | ||
approx_nonzero_indices | ||
] | ||
|
||
# scale by hop ratio | ||
inverse_transform *= float(self.filter_length) / self.hop_length | ||
|
||
inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] | ||
inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] | ||
|
||
return inverse_transform | ||
|
||
def forward(self, input_data): | ||
self.magnitude, self.phase = self.transform(input_data) | ||
reconstruction = self.inverse(self.magnitude, self.phase) | ||
return reconstruction | ||
|
||
|
||
class TacotronSTFT(torch.nn.Module): | ||
def __init__( | ||
self, | ||
filter_length, | ||
hop_length, | ||
win_length, | ||
n_mel_channels, | ||
sampling_rate, | ||
mel_fmin, | ||
mel_fmax, | ||
): | ||
super(TacotronSTFT, self).__init__() | ||
self.n_mel_channels = n_mel_channels | ||
self.sampling_rate = sampling_rate | ||
self.stft_fn = STFT(filter_length, hop_length, win_length) | ||
mel_basis = librosa_mel_fn( | ||
sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax | ||
) | ||
mel_basis = torch.from_numpy(mel_basis).float() | ||
self.register_buffer("mel_basis", mel_basis) | ||
|
||
def spectral_normalize(self, magnitudes, normalize_fun): | ||
output = dynamic_range_compression(magnitudes, normalize_fun) | ||
return output | ||
|
||
def spectral_de_normalize(self, magnitudes): | ||
output = dynamic_range_decompression(magnitudes) | ||
return output | ||
|
||
def mel_spectrogram(self, y, normalize_fun=torch.log): | ||
"""Computes mel-spectrograms from a batch of waves | ||
PARAMS | ||
------ | ||
y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] | ||
RETURNS | ||
------- | ||
mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) | ||
""" | ||
assert torch.min(y.data) >= -1 | ||
assert torch.max(y.data) <= 1 | ||
|
||
magnitudes, phases = self.stft_fn.transform(y) | ||
magnitudes = magnitudes.data | ||
mel_output = torch.matmul(self.mel_basis, magnitudes) | ||
mel_output = self.spectral_normalize(mel_output, normalize_fun) | ||
energy = torch.norm(magnitudes, dim=1) | ||
|
||
return mel_output, energy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import torch | ||
import numpy as np | ||
from scipy.io.wavfile import write | ||
import pickle | ||
import json | ||
from audioldm_eval.audio.audio_processing import griffin_lim | ||
|
||
|
||
def save_pickle(obj, fname): | ||
print("Save pickle at " + fname) | ||
with open(fname, "wb") as f: | ||
pickle.dump(obj, f) | ||
|
||
|
||
def load_pickle(fname): | ||
print("Load pickle at " + fname) | ||
with open(fname, "rb") as f: | ||
res = pickle.load(f) | ||
return res | ||
|
||
|
||
def write_json(my_dict, fname): | ||
print("Save json file at " + fname) | ||
json_str = json.dumps(my_dict) | ||
with open(fname, "w") as json_file: | ||
json_file.write(json_str) | ||
|
||
|
||
def load_json(fname): | ||
with open(fname, "r") as f: | ||
data = json.load(f) | ||
return data | ||
|
||
|
||
def get_mel_from_wav(audio, _stft): | ||
audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) | ||
audio = torch.autograd.Variable(audio, requires_grad=False) | ||
melspec, energy = _stft.mel_spectrogram(audio) | ||
melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) | ||
energy = torch.squeeze(energy, 0).numpy().astype(np.float32) | ||
return melspec, energy | ||
|
||
|
||
def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60): | ||
mel = torch.stack([mel]) | ||
mel_decompress = _stft.spectral_de_normalize(mel) | ||
mel_decompress = mel_decompress.transpose(1, 2).data.cpu() | ||
spec_from_mel_scaling = 1000 | ||
spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis) | ||
spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0) | ||
spec_from_mel = spec_from_mel * spec_from_mel_scaling | ||
|
||
audio = griffin_lim( | ||
torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters | ||
) | ||
|
||
audio = audio.squeeze() | ||
audio = audio.cpu().numpy() | ||
audio_path = out_filename | ||
write(audio_path, _stft.sampling_rate, audio) |
Oops, something went wrong.