-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference.py
79 lines (58 loc) · 2.36 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import sys
sys.path.append('waveglow/')
import numpy as np
import torch
import torchaudio
import random
import backend_helpers as helper
from hparams import create_hparams
from model import Tacotron2
from layers import TacotronSTFT, STFT
from audio_processing import griffin_lim
from train import load_model
from text import text_to_sequence
from waveglow.denoiser import Denoiser
from mutagen.mp3 import MP3
import logger
class Inference:
    """Text-to-speech pipeline: Tacotron 2 mel prediction + WaveGlow vocoding.

    Call :meth:`load_model` once to move the networks to the GPU, then call
    :meth:`infer` for every sentence that should be synthesised.
    """

    def __init__(self,
                 checkpoint_path="outdir/tacotron2_statedict_new.pt",
                 waveglow_path='model/waveglow_256channels.pt'):
        """Store configuration; no weights are loaded here.

        Args:
            checkpoint_path: Tacotron 2 checkpoint file; it must contain a
                ``'state_dict'`` entry.
            waveglow_path: WaveGlow checkpoint file; it must contain a
                ``'model'`` entry holding the pickled module.
        """
        self.hparams = create_hparams()
        self.checkpoint_path = checkpoint_path
        self.waveglow_path = waveglow_path
        # Populated by load_model(); None means "not loaded yet".
        self.model = None
        self.waveglow = None
        self.denoiser = None

    def load_model(self):
        """Load Tacotron 2, WaveGlow and the denoiser onto the GPU in fp16.

        Failures are logged and swallowed, as before. The instance attributes
        are only assigned once every component has loaded, so a failure never
        leaves a half-initialised pipeline behind.
        """
        try:
            self.hparams.sampling_rate = 22050

            # `load_model` here is the module-level function imported from
            # `train`, not this method (class scope is not searched).
            model = load_model(self.hparams)
            # SECURITY: torch.load unpickles arbitrary objects; only load
            # checkpoint files from a trusted source.
            checkpoint = torch.load(self.checkpoint_path)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda().eval().half()

            waveglow = torch.load(self.waveglow_path)['model']
            # The pickled Conv modules may lack the `padding_mode` attribute
            # that newer PyTorch Conv layers read, so set it explicitly.
            for module in waveglow.modules():
                if 'Conv' in str(type(module)):
                    setattr(module, 'padding_mode', 'zeros')
            waveglow.cuda().eval().half()
            # The `convinv` layers are cast back to float32 after .half()
            # (presumably for numerical stability -- confirm with WaveGlow).
            for conv in waveglow.convinv:
                conv.float()

            denoiser = Denoiser(waveglow)
        except Exception as e:
            logger.error("Model loading failed", e)
            return

        self.model = model
        self.waveglow = waveglow
        self.denoiser = denoiser

    def infer(self, sentence, job_text_id):
        """Synthesise ``sentence`` to ``tmp/<job_text_id>.mp3``.

        Args:
            sentence: Text to speak; cleaned with ``english_cleaners``.
            job_text_id: Identifier used as the output file stem. Any value
                convertible with ``str()`` is accepted.

        Returns:
            The length reported by mutagen for the written file
            (``MP3(...).info.length``), or ``None`` if saving or reading the
            file failed (the error is logged).

        Raises:
            RuntimeError: If :meth:`load_model` has not completed successfully.
        """
        if self.model is None or self.waveglow is None or self.denoiser is None:
            raise RuntimeError(
                "Inference models are not loaded; call load_model() first")

        # Add a leading batch dimension of 1 and force int64 indices.
        sequence = np.array(
            text_to_sequence(sentence, ['english_cleaners']))[None, :]
        sequence = torch.from_numpy(sequence).cuda().long()

        # The whole forward pass runs without autograd bookkeeping.
        with torch.no_grad():
            _, mel_outputs_postnet, _, _ = self.model.inference(sequence)
            audio = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
            audio_denoised = self.denoiser(audio, strength=0.01)[:, 0]
        audio_denoised = audio_denoised.cpu().float()

        filename = f"tmp/{job_text_id}.mp3"
        try:
            torchaudio.save(filename, audio_denoised,
                            self.hparams.sampling_rate)
            mp3_file = MP3(filename)
            print(mp3_file.info.length)
            return mp3_file.info.length
        except Exception as e:
            logger.error("Inference failed", e)
            return None