-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathutils.py
More file actions
169 lines (127 loc) · 4.28 KB
/
utils.py
File metadata and controls
169 lines (127 loc) · 4.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import errno
import re
from pathlib import Path
from time import perf_counter, time
from typing import Self

import torch
from exllamav2 import (
    ExLlamaV2Cache,
    ExLlamaV2CacheBase,
    ExLlamaV2Cache_Q4,
    ExLlamaV2Cache_Q6,
    ExLlamaV2Cache_Q8,
)
from rich import print
from rich.progress import (
    BarColumn,
    Progress as ProgressBar,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
    TaskProgressColumn,
)
from semantic_text_splitter import TextSplitter
from torchaudio import functional as F
class Progress:
    """Context-managed rich progress bar that tracks a single task.

    Entering the context starts the bar and registers one task named
    ``description`` with ``total`` units of work. Calling the instance advances
    that task. Leaving the context fills the task to ``total`` and stops the bar.
    """

    def __init__(self, description: str, total: float = 1.0) -> None:
        self.description = description
        self.total = total
        self.task = None  # rich task id; assigned in __enter__
        columns = (
            TextColumn(f"[green]INFO[/green]:{' ' * 5}{{task.description}}"),
            BarColumn(bar_width=None),
            TaskProgressColumn(),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
        )
        self.progress = ProgressBar(*columns)

    def __enter__(self) -> Self:
        bar = self.progress
        bar.start()
        self.task = bar.add_task(self.description, total=self.total)
        return self

    def __exit__(self, *args) -> None:
        # Show the task as finished regardless of how far it was advanced.
        bar = self.progress
        bar.update(self.task, completed=self.total)
        bar.stop()

    def __call__(self, *args, advance: float = 1.0) -> None:
        """Advance the task by ``advance`` units.

        Positional arguments are accepted and ignored; presumably this lets the
        instance be passed directly as a callback (confirm against callers).
        """
        self.progress.advance(self.task, advance)
class Timer:
def __init__(self) -> None:
self.start = 0.0
self.end = 0.0
self.total = 0.0
def __enter__(self) -> Self:
self.start = time()
return self
def __exit__(self, *args) -> None:
self.end = time()
self.total = self.end - self.start
def __call__(self, text: str, precision: int = 2) -> None:
log(f"{text} in {self.total:.{precision}f} seconds.")
def log(text: str) -> None:
    """Print ``text`` after a green ``INFO:`` prefix written in rich markup."""
    prefix = f"[green]INFO[/green]:{' ' * 5}"
    print(f"{prefix}{text}")
def get_cache(cache: str) -> type[ExLlamaV2CacheBase]:
    """Return the ExLlamaV2 cache *class* (not an instance) for ``cache``.

    ``"q4"``, ``"q6"`` and ``"q8"`` (case-insensitive) select the matching
    quantized cache class. Any other value falls back to ``ExLlamaV2Cache``.
    The caller is responsible for instantiating the returned class.
    """
    quantized = {
        "q4": ExLlamaV2Cache_Q4,
        "q6": ExLlamaV2Cache_Q6,
        "q8": ExLlamaV2Cache_Q8,
    }
    return quantized.get(cache.lower(), ExLlamaV2Cache)
def get_dtype(dtype: str) -> torch.dtype:
    """Translate a precision name into a ``torch.dtype``.

    ``"fp16"`` and ``"bf16"`` (case-insensitive) select the matching 16-bit
    type. Every other name yields ``torch.float32``.
    """
    half_precision = {"fp16": torch.float16, "bf16": torch.bfloat16}
    return half_precision.get(dtype.lower(), torch.float32)
def get_pairs(text: list[str] | str, default: str) -> list[dict[str, str]]:
if isinstance(text, str):
text = text.splitlines()
pattern = re.compile(r"(?:{{((?:[^{}]|{(?!{)|}(?!}))*)}})([^{]*)")
default = default.lower()
chunks = []
for chunk in text:
chunk = chunk.strip()
matches = list(re.finditer(pattern, chunk))
if not matches:
chunks.append({"voice": default, "text": chunk})
continue
for match in matches:
voice = match.group(1).strip().lower()
text = match.group(2).strip()
chunks.append({"voice": voice, "text": text})
return chunks
def process_audio(
    audio: torch.Tensor, input_rate: int, output_rate: int, max_len: float = 0
) -> torch.Tensor:
    """Down-mix, optionally trim, and resample an audio tensor.

    Args:
        audio: waveform laid out as ``(channels, samples)``. A 1-D tensor is
            treated as a single channel of samples.
        input_rate: sample rate of ``audio`` in Hz.
        output_rate: desired sample rate in Hz. Resampling is skipped when it
            equals ``input_rate``.
        max_len: maximum duration in seconds. ``0`` disables trimming.
            Fractional values are accepted.

    Returns:
        A single-channel ``(1, samples)`` tensor at ``output_rate``.
    """
    if audio.dim() == 1:
        # Give mono input a channel axis. Otherwise the mean below would collapse
        # every sample into a single value.
        audio = audio.unsqueeze(0)
    if audio.shape[0] > 1:
        # Down-mix to mono by averaging across channels.
        audio = torch.mean(audio, dim=0, keepdim=True)
    if max_len and audio.shape[1] / input_rate > max_len:
        # Slicing needs an integer bound, even when max_len is fractional.
        audio = audio[:, : int(input_rate * max_len)]
    if input_rate != output_rate:
        audio = F.resample(audio, input_rate, output_rate)
    return audio
def clean_text(text: list[str] | str) -> str:
if isinstance(text, list):
text = "\n".join(text)
lines = [" ".join(l.split()) for l in text.splitlines()]
lines = [l.strip() for l in lines if l.strip()]
return "\n".join(lines)
def process_text(
    text: str, suffixes: list[str] | tuple[str, ...] = (".txt",)
) -> str:
    """Load text from a directory, a file, or a literal string, then clean it.

    Args:
        text: a directory path, a file path, or otherwise the literal text.
            For a directory, every file directly inside it whose suffix is in
            ``suffixes`` is read as UTF-8, in sorted path order. A file path is
            read as UTF-8. Anything else (including an empty string) is used
            as-is.
        suffixes: file extensions, including the leading dot, that are read
            when ``text`` is a directory. The match is case-sensitive.

    Returns:
        The loaded text passed through ``clean_text``.

    Raises:
        OSError: if a selected file cannot be read, or if probing ``text`` as a
            path fails for a reason other than the name being too long.
    """
    if not text:
        # Path("") normalises to Path("."), which would read the working directory.
        return clean_text(text)

    path = Path(text)
    try:
        is_dir = path.is_dir()
        is_file = not is_dir and path.is_file()
    except OSError as err:
        # Long literal text is not a valid path name, and stat() then fails with
        # ENAMETOOLONG rather than "not found". Treat it as plain text.
        if err.errno != errno.ENAMETOOLONG:
            raise
        is_dir = is_file = False

    if is_dir:
        # glob() order depends on the filesystem; sort for a reproducible order.
        files = sorted(
            f for f in path.glob("*.*") if f.suffix in suffixes and f.is_file()
        )
        return clean_text([f.read_text(encoding="utf-8") for f in files])
    if is_file:
        return clean_text(path.read_text(encoding="utf-8"))
    return clean_text(text)
def split_text(text: list[str] | str, max_len: int) -> list[str]:
    """Break text into chunks of at most ``max_len`` using ``TextSplitter``.

    A string is first split into lines. Each line is chunked independently,
    and the chunks are returned flattened in their original order.
    """
    lines = text.splitlines() if isinstance(text, str) else text
    splitter = TextSplitter(max_len)
    return [piece for line in lines for piece in splitter.chunks(line)]