# xp_edition_mlm_params.py
# Forked from bostxavier/Serial-Speakers.
from typing import Optional, Literal
import time
from more_itertools import flatten
from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred.commands import print_config
from sacred.run import Run
from sacred.utils import apply_backspaces_and_linefeeds
from tqdm import tqdm
from novelshare.hash import hash_tokens
from novelshare.align import align_tokens, make_plugin_mlm
from novelshare.experiments.data import (
iter_book_chapters,
normalize_,
EDITION_SETS,
)
from novelshare.experiments.metrics import log_alignment_metrics_
# Module-level sacred experiment; the config scope and the entry point
# defined below in this file are registered on it through decorators.
ex = Experiment()
# Record every run through a file-storage observer rooted at "runs".
ex.observers.append(FileStorageObserver("runs"))
# Filter captured output through sacred's backspace/linefeed handler
# (presumably so tqdm's in-place redraws stay readable in the stored
# output -- confirm against the sacred documentation).
ex.captured_out_filter = apply_backspaces_and_linefeeds  # type: ignore
@ex.config
def config():
    """Declare the configuration entries of the experiment.

    The local names below are the configuration keys; they match the
    parameters of ``main`` one-to-one.

    * ``novel``: key into ``EDITION_SETS`` (checked in ``main``).  Declared
      without a value here -- presumably it must be supplied when the
      experiment is launched; confirm against sacred's config handling.
    * ``window_range``: the values handed, one at a time, to
      ``make_plugin_mlm`` as its window argument.  Also declared without a
      value here.
    * ``hash_len``: forwarded to ``hash_tokens`` and ``align_tokens``;
      ``main`` requires ``0 < hash_len <= 64``.  Defaults to 64.
    * ``chapter_limit``: forwarded to ``iter_book_chapters``.  Defaults to
      ``None``.
    * ``device``: forwarded to ``make_plugin_mlm``; one of ``"auto"``,
      ``"cuda"`` or ``"cpu"``.  Defaults to ``"auto"``.
    """
    novel: str
    window_range: list[int]
    hash_len: int = 64
    chapter_limit: Optional[int] = None
    device: Literal["auto", "cuda", "cpu"] = "auto"
@ex.automain
def main(
    _run: Run,
    novel: str,
    window_range: list[int],
    hash_len: int,
    chapter_limit: Optional[int],
    device: Literal["auto", "cuda", "cpu"],
):
    """Align each non-reference edition of ``novel`` for every window size.

    The first edition listed in ``EDITION_SETS[novel]`` is taken as the
    reference; every other edition is aligned against its hashed tokens
    once per value of ``window_range``, using a ModernBERT MLM alignment
    plugin built for that window.  Each (edition, window) pair is reported
    through ``log_alignment_metrics_`` together with the CPU time spent in
    ``align_tokens``.

    :param _run: the current sacred run, used for config printing and
        metric logging.
    :param novel: key of ``EDITION_SETS`` selecting the edition set.
    :param window_range: window values passed to ``make_plugin_mlm``.
    :param hash_len: hash length given to ``hash_tokens`` and
        ``align_tokens``; must satisfy ``0 < hash_len <= 64``.
    :param chapter_limit: forwarded to ``iter_book_chapters``.
    :param device: forwarded to ``make_plugin_mlm``.
    :raises ValueError: if ``novel`` is not in ``EDITION_SETS``, if its
        edition set is empty, or if ``hash_len`` is out of range.
    """
    print_config(_run)

    # Explicit exceptions instead of ``assert``: assertions are stripped
    # when Python runs with -O, which would silently skip the validation.
    if novel not in EDITION_SETS:
        raise ValueError(f"unknown novel {novel!r}: not a key of EDITION_SETS")
    if not 0 < hash_len <= 64:
        raise ValueError(f"hash_len must be in the range 1..64, got {hash_len!r}")

    editions = EDITION_SETS[novel]
    if not editions:
        raise ValueError(f"no edition registered for novel {novel!r}")

    # The first listed edition serves as the reference.
    reference_edition = next(iter(editions))
    reference_chapters = list(
        iter_book_chapters(editions[reference_edition], chapter_limit=chapter_limit)
    )
    wild_editions = {
        key: list(iter_book_chapters(path, chapter_limit=chapter_limit))
        for key, path in editions.items()
        if key != reference_edition
    }

    normalize_(reference_chapters)
    for chapters in wild_editions.values():
        normalize_(chapters)

    # Neither of these depends on the edition or the window, so they are
    # computed once (after normalization) rather than inside the loops.
    reference_tokens = list(flatten(reference_chapters))
    reference_hashed_chapters = [
        hash_tokens(chapter, hash_len=hash_len) for chapter in reference_chapters
    ]
    # Flattened view of the hashed reference, built on first use only.
    reference_hashed_flat = None

    # The context manager closes the progress bar even if alignment raises.
    with tqdm(total=len(wild_editions) * len(window_range), ascii=True) as progress:
        for edition, edition_chapters in wild_editions.items():
            # If we have the same number of chapters, we assume that
            # the chapters are aligned. Otherwise, we have to perform
            # alignment on the entire novels, not on individual
            # chapters (more costly!)
            if len(reference_chapters) == len(edition_chapters):
                reference_hashed = reference_hashed_chapters
                user_tokens = edition_chapters
            else:
                if reference_hashed_flat is None:
                    reference_hashed_flat = list(flatten(reference_hashed_chapters))
                reference_hashed = reference_hashed_flat
                user_tokens = list(flatten(edition_chapters))

            for window in window_range:
                progress.set_description(f"{edition}.w={window}")

                mlm = make_plugin_mlm(
                    "answerdotai/ModernBERT-base", window, device=device
                )

                # process_time() counts CPU time of this process, so plugin
                # construction above is deliberately left out of the timing.
                t0 = time.process_time()
                aligned_tokens = align_tokens(
                    reference_hashed,
                    user_tokens,
                    hash_len=hash_len,
                    alignment_plugins=[mlm],
                )
                t1 = time.process_time()

                setup_name = f"w={window}.e={novel},{edition}"
                log_alignment_metrics_(
                    _run,
                    setup_name,
                    reference_tokens,
                    aligned_tokens,
                    t1 - t0,
                )

                progress.update()