synthesize.py (forked from atomicoo/FCH-TTS)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    :  synthesize.py
@Date    :  2021/01/05, Tue
@Author  :  Atomicoo
@Version :  1.0
@Contact :  [email protected]
@License :  (C)Copyright 2020-2021, ShiGroup-NLP-XMU
@Desc    :  Synthesize sentences into speech.
'''
__author__ = 'Atomicoo'
import argparse
import os
import os.path as osp
import time
from scipy.io.wavfile import write
import torch
from utils.hparams import HParam
from utils.transform import StandardNorm
from helpers.synthesizer import Synthesizer
from vocoder_melgan.melgan import Generator
from datasets.dataset import TextProcessor
from models import ParallelText2Mel
from utils.utils import select_device, get_last_chkpt_path
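# GPUManager (from helpers.manager) auto-selects a GPU via gm.auto_choice(); if it cannot be imported, fall back to gm = None.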
try:
    from helpers.manager import GPUManager
except ImportError as err:
    print(err)
    gm = None
else:
    gm = GPUManager()
if __name__ == '__main__':
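    # Command-line options; values left unset fall back to the config file or sensible defaults.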
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--batch_size", default=8, type=int, help="Batch size")
    parser.add_argument("--checkpoint", default=None, type=str, help="Checkpoint file path")
    parser.add_argument("--melgan_checkpoint", default=None, type=str, help="Checkpoint file path of melgan")
    parser.add_argument("--input_texts", default=None, type=str, help="Input text file path")
    parser.add_argument("--outputs_dir", default=None, type=str, help="Output wave file directory")
    parser.add_argument("--device", default=None, help="cuda device or cpu")
    parser.add_argument("--name", default="parallel", type=str, help="Append to logdir name")
    parser.add_argument("--config", default=None, type=str, help="Config file path")
    args = parser.parse_args()
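    # Pick a CUDA device: honor --device, else let GPUManager choose, else default to GPU 0; otherwise run on CPU.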
    if torch.cuda.is_available():
        index = args.device if args.device else str(0 if gm is None else gm.auto_choice())
    else:
        index = 'cpu'
    device = select_device(index)
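    # Load hyperparameters from --config, falling back to config/default.yaml under the current working directory.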
    hparams = HParam(args.config) \
        if args.config else HParam(osp.join(osp.abspath(os.getcwd()), "config", "default.yaml"))
    logdir = osp.join(hparams.trainer.logdir, f"{hparams.data.dataset}-{args.name}")
    checkpoint = args.checkpoint or get_last_chkpt_path(logdir)
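    # Text front-end, spectrogram normalizer, and the ParallelText2Mel acoustic model in eval mode.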
    normalizer = StandardNorm(hparams.audio.spec_mean, hparams.audio.spec_std)
    processor = TextProcessor(hparams.text)
    text2mel = ParallelText2Mel(hparams.parallel)
    text2mel.eval()
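    # The Synthesizer loads the checkpoint into the model and runs text-to-mel inference on the chosen device.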
    synthesizer = Synthesizer(
        model=text2mel,
        checkpoint=checkpoint,
        processor=processor,
        normalizer=normalizer,
        device=device
    )
    print('Synthesizing...')
    since = time.time()
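    # Read input sentences, one per line, from --input_texts or the configured synthesizer.inputs_file_path.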
    text_file = args.input_texts or hparams.synthesizer.inputs_file_path
    with open(text_file, 'r', encoding='utf-8') as fr:
        texts = fr.read().strip().split('\n')
    melspecs = synthesizer.inference(texts)
    print(f"Inferred {len(texts)} spectrograms, total elapsed {time.time()-since:.3f}s. Done.")
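    # MelGAN vocoder: build the generator, switch it to inference mode via the MelGAN-specific eval flag, load its weights, then convert mels to waveforms.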
    vocoder = Generator(hparams.audio.n_mel_channels).to(device)
    vocoder.eval(inference=True)
    vocoder_checkpoint = args.melgan_checkpoint or \
        osp.join(hparams.trainer.logdir, f"{hparams.data.dataset}-melgan", hparams.melgan.checkpoint)
    vocoder.load_state_dict(torch.load(vocoder_checkpoint, map_location=device))
    waves = vocoder(melspecs).squeeze(1)
    print(f"Generated {len(texts)} audio clips, total elapsed {time.time()-since:.3f}s. Done.")
    print('Saving audio...')
    outputs_dir = args.outputs_dir or hparams.synthesizer.outputs_dir
    os.makedirs(outputs_dir, exist_ok=True)
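    # Write each waveform to <outputs_dir>/<YYYY-MM-DD>_<index>.wav at the configured sampling rate.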
    for i, wav in enumerate(waves, start=1):
        wav = wav.cpu().detach().numpy()
        filename = osp.join(outputs_dir, f"{time.strftime('%Y-%m-%d')}_{i:03d}.wav")
        write(filename, hparams.audio.sampling_rate, wav)
    print(f"Audio files saved to {outputs_dir}. Done.")
    print(f'Done. ({time.time()-since:.3f}s)')