Source: __main__.py (95 lines / 77 loc, 2.76 KB) — from a repository forked from rhasspy/piper.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import argparse
import json
import logging
from pathlib import Path
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from .vits.lightning import VitsModel
_LOGGER = logging.getLogger(__package__)
def main():
    """Train a VITS voice model on a pre-processed piper dataset.

    Parses command-line arguments (dataset location, checkpoint cadence,
    model-quality preset, plus all Trainer/VitsModel arguments), reads the
    dataset's ``config.json`` for symbol/speaker counts and sample rate,
    then runs ``trainer.fit`` on a ``VitsModel``.
    """
    logging.basicConfig(level=logging.DEBUG)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset-dir", required=True, help="Path to pre-processed dataset directory"
    )
    parser.add_argument(
        "--checkpoint-epochs",
        type=int,
        # Bug fix: the old help text claimed "(default: 1)", but no argparse
        # default is set — omitting the flag leaves checkpointing to Lightning.
        help="Save checkpoint every N epochs (default: Lightning's checkpointing)",
    )
    parser.add_argument(
        "--quality",
        default="medium",
        choices=("x-low", "medium", "high"),
        help="Quality/size of model (default: medium)",
    )
    # Let Lightning and the model register their own CLI arguments.
    Trainer.add_argparse_args(parser)
    VitsModel.add_model_specific_args(parser)
    parser.add_argument("--seed", type=int, default=1234)
    args = parser.parse_args()
    _LOGGER.debug(args)

    args.dataset_dir = Path(args.dataset_dir)
    if not args.default_root_dir:
        # Keep logs/checkpoints next to the dataset unless told otherwise.
        args.default_root_dir = args.dataset_dir

    torch.backends.cudnn.benchmark = True
    torch.manual_seed(args.seed)

    config_path = args.dataset_dir / "config.json"
    dataset_path = args.dataset_dir / "dataset.jsonl"

    with open(config_path, "r", encoding="utf-8") as config_file:
        # See preprocess.py for format
        config = json.load(config_file)
        num_symbols = int(config["num_symbols"])
        num_speakers = int(config["num_speakers"])
        sample_rate = int(config["audio"]["sample_rate"])

    trainer = Trainer.from_argparse_args(args)
    if args.checkpoint_epochs is not None:
        # NOTE(review): this replaces ALL of the trainer's callbacks, not just
        # the checkpoint callback — presumably intentional; confirm upstream.
        trainer.callbacks = [ModelCheckpoint(every_n_epochs=args.checkpoint_epochs)]
        _LOGGER.debug(
            "Checkpoints will be saved every %s epoch(s)", args.checkpoint_epochs
        )

    dict_args = vars(args)
    if args.quality == "x-low":
        # Shrink the network for the smallest preset.
        dict_args["hidden_channels"] = 96
        dict_args["inter_channels"] = 96
        dict_args["filter_channels"] = 384
    elif args.quality == "high":
        # Larger decoder configuration for the high-quality preset.
        dict_args["resblock"] = "1"
        dict_args["resblock_kernel_sizes"] = (3, 7, 11)
        dict_args["resblock_dilation_sizes"] = (
            (1, 3, 5),
            (1, 3, 5),
            (1, 3, 5),
        )
        dict_args["upsample_rates"] = (8, 8, 2, 2)
        dict_args["upsample_initial_channel"] = 512
        dict_args["upsample_kernel_sizes"] = (16, 16, 4, 4)

    model = VitsModel(
        num_symbols=num_symbols,
        num_speakers=num_speakers,
        sample_rate=sample_rate,
        dataset=[dataset_path],
        **dict_args,
    )
    trainer.fit(model)
# -----------------------------------------------------------------------------
# Script entry point: run training when executed as `python -m <package>`.
if __name__ == "__main__":
    main()