-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
59 lines (51 loc) · 2.01 KB
/
main.py
File metadata and controls
59 lines (51 loc) · 2.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""
Entry point of the project: command-line argument handling plus the functions
that drive model training and text-to-speech inference.
"""
from config import *
from utils.globals import *
from utils.data import get_data, get_dataloaders
from models.training import get_grad_tts_model, load_latest_checkpt, \
train
from models.inference import convert_text_to_speech
from argparse import ArgumentParser
# Model instance shared by run_training() and run_inference(). It starts as
# None and is created lazily on first use, so a single invocation that passes
# both --train and --tts reuses the same model object.
grad_tts_model = None
def run_training():
    """Train the Grad-TTS model.

    Loads the dataset, wraps it in train/validation dataloaders, creates the
    shared module-level model if it does not exist yet, and passes all of it
    to ``train``.
    """
    global grad_tts_model
    loaders = get_dataloaders(get_data())
    if grad_tts_model is None:
        grad_tts_model = get_grad_tts_model()
    train_loader, validation_loader = loaders['train'], loaders['validation']
    train(grad_tts_model, train_loader, validation_loader)
def run_inference(text_or_file, from_file, out_file=None):
    """Convert text to speech with the most recent model checkpoint.

    Args:
        text_or_file: The text to convert or, when ``from_file`` is true, a
            path to a file containing the text.
        from_file: Whether ``text_or_file`` should be treated as a file path
            instead of literal text.
        out_file: Output path, forwarded unchanged to
            ``convert_text_to_speech`` (may be ``None``).
    """
    global grad_tts_model
    model = grad_tts_model
    if model is None:
        model = grad_tts_model = get_grad_tts_model()
    # Reload the latest checkpoint on every call, even if the model object
    # already existed (e.g. it was just used for training in this run).
    load_latest_checkpt(model)
    convert_text_to_speech(text_or_file, model, out_file, from_file)
if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("--train", help="Run training", action="store_true")
    parser.add_argument("--tts", help="Text to convert to speech.")
    parser.add_argument("--file", help="If set, the text passed in is interpreted as"
                        " a file containing text to be converted - must be a .txt file.",
                        action='store_true')
    parser.add_argument("--out", help="Path to output file. If --file is passed, this"
                        " can be left unspecified in which case X.txt will produce X.out.")
    args = parser.parse_args()

    wants_tts = args.tts is not None

    # Validate all arguments before doing any work. Otherwise a combined
    # `--train --tts ...` run could finish a long training job and only then
    # be rejected for a missing --out.
    if wants_tts and args.out is None and not args.file:
        # parser.error prints the usage line and this message to stderr and
        # exits with status 2. The previous exit(0) reported success to the
        # calling shell/script even though nothing was produced.
        parser.error("Cannot leave out-file unspecified when passing text!")

    if args.train:
        run_training()
    if wants_tts:
        run_inference(args.tts, args.file, args.out)
    if not (args.train or wants_tts):
        # Nothing was requested: show how to use the script.
        parser.print_usage()