-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathmodal_test_app.py
More file actions
198 lines (160 loc) · 6.94 KB
/
modal_test_app.py
File metadata and controls
198 lines (160 loc) · 6.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
#!/usr/bin/env python3
"""
Deploy locally-built pyllmq wheel to Modal and run recomputation tests.
Usage:
modal run modal_test_app.py [-- test args...]
"""
import io
import sys
from pathlib import Path
import modal
def create_image(cuda_version: str = "12.8.1"):
    """Create the Modal image with the local pyllmq wheel and its dependencies.

    Searches ``wheelhouse/`` first and then ``dist/`` (both relative to the
    current working directory) for files matching ``pyllmq*.whl``. The first
    directory that contains at least one match is used.

    Args:
        cuda_version: Version tag used to select the ``nvidia/cuda`` base image.

    Returns:
        The image object produced by the Modal builder chain below.

    Raises:
        FileNotFoundError: If neither directory contains a matching wheel.
        SystemExit: With exit code 1 if the chosen directory contains more
            than one matching wheel.
    """
    # Check where we can find some wheels; wheelhouse/ takes precedence.
    wheel_files: list[Path] = []
    for wheel_dir in (Path("wheelhouse"), Path("dist")):
        if wheel_dir.exists():
            # Sorted so the listing printed on ambiguity is deterministic.
            wheel_files = sorted(wheel_dir.glob("pyllmq*.whl"))
            if wheel_files:
                break

    if not wheel_files:
        raise FileNotFoundError("No wheels found in wheelhouse/ or dist/")

    if len(wheel_files) > 1:
        # Refuse to guess which wheel is meant to be tested.
        print("Error: multiple wheels found:", file=sys.stderr)
        for wheel_file in wheel_files:
            print(f" {wheel_file}", file=sys.stderr)
        sys.exit(1)

    wheel_file = wheel_files[0]
    print(f"Using wheel: {wheel_file}")

    return (
        modal.Image.from_registry(f"nvidia/cuda:{cuda_version}-cudnn-devel-ubuntu24.04", add_python="3.12")
        # install dependencies for pyllmq, so that rebuild steps are faster
        .uv_pip_install("huggingface_hub", "transformers", "datasets", "numpy", "torch", "accelerate",
                        "nvidia-cuda-runtime-cu12>=12.8", "nvidia-cudnn-cu12>=9.0", "nvidia-nccl-cu12>=2.0", "nvidia-cublas-cu12>=12.8")
        .run_commands("hf download Qwen/Qwen2.5-0.5B")
        .add_local_file("scripts/tokenize_data.py", "/root/tokenize_data.py", copy=True)
        .run_commands(
            "uv run /root/tokenize_data.py --dataset tiny-shakespeare --model qwen --out-dir /root/data"
        )
        # The wheel is added last so that changing it only invalidates the final steps.
        .add_local_file(str(wheel_file), f"/tmp/{wheel_file.name}", copy=True)
        .pip_install(f"/tmp/{wheel_file.name}")
    )
# create_image() reads wheels from the local wheelhouse/ or dist/ directory,
# so it is only called when this module is evaluated locally; otherwise a
# plain debian_slim image object is used in its place.
# NOTE(review): presumably the remote container already runs the built image,
# making this a placeholder there -- confirm against the Modal docs.
image = create_image() if modal.is_local() else modal.Image.debian_slim()

app = modal.App("llmq-test")
def compare_and_create_report(result, expected):
    """Compare two run results and capture the textual comparison.

    Args:
        result: The result to check; passed first to ``compare_results``.
        expected: The reference to check against; passed second.

    Returns:
        A dict with ``"passed"`` (the value returned by ``compare_results``)
        and ``"report"`` (everything ``compare_results`` wrote to its
        ``file`` argument).
    """
    from pyllmq.tests.run import compare_results

    # Collect the report in memory so it can be returned to the caller.
    with io.StringIO() as sink:
        outcome = compare_results(result, expected, file=sink)
        text = sink.getvalue()
    return {"passed": outcome, "report": text}
@app.function(
    gpu="L4",
    memory=8192,
    timeout=300,
    image=image,
)
def run_recompute_test(test_args: list[str]):
    """Run the recomputation test on Modal.

    Trains once with the configuration parsed from ``test_args`` and once
    with the configuration returned by ``disable_recompute``, then compares
    the first run against the second.

    Args:
        test_args: Command-line style arguments handed to ``parse_args``.

    Returns:
        The ``{"passed": ..., "report": ...}`` dict built by
        ``compare_and_create_report``.
    """
    from pyllmq.tests.run import parse_args, run_training
    from pyllmq.tests.recompute import disable_recompute

    # Parse test arguments into config, then point it at the data and model
    # that were baked into the image.
    cfg = parse_args(test_args)
    cfg.train_file = "/root/data/tiny-shakespeare-qwen/train.bin"
    cfg.model = "Qwen/Qwen2.5-0.5B"

    with_recompute = run_training(cfg)
    without_recompute = run_training(disable_recompute(cfg))
    return compare_and_create_report(with_recompute, without_recompute)
def run_with_config(test_args: list[str]):
    """Parse ``test_args`` into a config and run one training job with it.

    The config's training file and model are overridden with the fixed
    dataset path and model name used throughout this file.

    Args:
        test_args: Command-line style arguments handed to ``parse_args``.

    Returns:
        Whatever ``run_training`` returns for the resulting config.
    """
    from pyllmq.tests.run import parse_args, run_training

    cfg = parse_args(test_args)
    cfg.train_file = "/root/data/tiny-shakespeare-qwen/train.bin"
    cfg.model = "Qwen/Qwen2.5-0.5B"
    return run_training(cfg)
@app.function(
    gpu="L4",
    memory=8192,
    timeout=300,
    image=image,
)
def run_fixed_result_test(dtype: str = "bf16"):
    """Run a training job on Modal and compare it with stored reference values.

    Args:
        dtype: One of ``"bf16"``, ``"e4m3"`` or ``"e5m2"``. For ``"e5m2"``
            the run uses ``--matmul-dtype=e4m3`` together with
            ``--gradient-dtype=e5m2``; the other values only set
            ``--matmul-dtype``.

    Returns:
        The ``{"passed": ..., "report": ...}`` dict built by
        ``compare_and_create_report``.

    Raises:
        ValueError: If ``dtype`` has no stored reference values. This is
            checked before training starts, so no GPU time is spent on an
            unsupported value.
    """
    from pyllmq.tests.run import RunResult

    print(f"Launching Modal fixed_result test with dtype: {dtype}")

    # Reference losses and gradient norms recorded for each supported dtype.
    expected_by_dtype = {
        "bf16": {
            "losses": [3.6712355613708496, 3.3381927013397217, 3.4416379928588867, 3.414004325866699, 3.184972047805786, 3.5811686515808105, 3.303946018218994, 3.1412110328674316, 3.3189358711242676, 3.1411116123199463, 2.9200551509857178],
            "norms": [7.183468818664551, 6.108993053436279, 5.932713031768799, 5.761547565460205, 6.189230918884277, 6.775102138519287, 5.833801746368408, 5.700308799743652, 5.810333728790283, 4.892214298248291],
        },
        "e4m3": {
            "losses": [3.6887805461883545, 3.3520102500915527, 3.457583427429199, 3.4244065284729004, 3.193877696990967, 3.5914981365203857, 3.3114070892333984, 3.153069019317627, 3.327822208404541, 3.1491622924804688, 2.9285335540771484],
            "norms": [7.288778305053711, 6.2950029373168945, 6.045900344848633, 6.014016628265381, 6.236341953277588, 6.587290287017822, 5.900122165679932, 5.704860687255859, 5.526883602142334, 4.838467597961426],
        },
        "e5m2": {
            "losses": [3.6887805461883545, 3.352607488632202, 3.4565486907958984, 3.4266669750213623, 3.1923203468322754, 3.5887439250946045, 3.309535026550293, 3.1504156589508057, 3.329827070236206, 3.1509664058685303, 2.938549041748047],
            "norms": [7.224483489990234, 6.2642621994018555, 6.038032054901123, 5.886837482452393, 6.202877521514893, 6.511382579803467, 5.829803466796875, 5.652224540710449, 5.565894603729248, 4.865762710571289],
        },
    }

    # Fail fast: reject unsupported dtypes before running the training job.
    if dtype not in expected_by_dtype:
        raise ValueError(f"Unknown dtype: {dtype}")
    expected = expected_by_dtype[dtype]

    if dtype == "e5m2":
        args = ["--matmul-dtype=e4m3", "--gradient-dtype=e5m2"]
    else:
        args = [f"--matmul-dtype={dtype}"]

    result = run_with_config(args)
    report = compare_and_create_report(result, RunResult(**expected))

    if not report["passed"]:
        import json
        import dataclasses

        # this helps with debugging/updating in case of failure
        print(json.dumps(dataclasses.asdict(result)))
    return report
@app.function(
    gpu="L4",
    memory=8192,
    timeout=300,
    image=image,
)
def run_torch_compare_step(test_args: list):
    """Run the single-step torch reference comparison on Modal.

    Args:
        test_args: Command-line style arguments handed to ``parse_args``.

    Returns:
        A dict with ``"passed"`` (the value returned by
        ``compare_single_step``) and ``"report"`` (everything it wrote to
        its ``file`` argument).
    """
    from pyllmq.tests.torch_reference import compare_single_step
    from pyllmq.tests.run import parse_args

    cfg = parse_args(test_args)
    cfg.train_file = "/root/data/tiny-shakespeare-qwen/train.bin"
    cfg.model = "Qwen/Qwen2.5-0.5B"

    # Collect the report in memory so it can be returned to the caller.
    with io.StringIO() as sink:
        outcome = compare_single_step(cfg, file=sink)
        text = sink.getvalue()
    return {"passed": outcome, "report": text}
@app.local_entrypoint()
def test_recompute(*test_args: str):
    """Launch the recomputation test remotely and exit with 1 if it fails.

    Args:
        *test_args: Arguments forwarded as a list to ``run_recompute_test``.
    """
    print(f"Launching Modal recomputation test with args: {test_args}")
    outcome = run_recompute_test.remote(list(test_args))

    # Print the comparison report
    report_text = "\n" + outcome["report"]
    print(report_text)

    if outcome["passed"]:
        return
    sys.exit(1)
@app.local_entrypoint()
def test_torch_step(*test_args: str):
    """Launch the torch single-step comparison remotely and exit with 1 if it fails.

    Args:
        *test_args: Arguments forwarded as a list to ``run_torch_compare_step``.
    """
    print(f"Launching Modal torch_compare_step test with args: {test_args}")
    outcome = run_torch_compare_step.remote(list(test_args))

    # Print the comparison report
    report_text = "\n" + outcome["report"]
    print(report_text)

    if outcome["passed"]:
        return
    sys.exit(1)
@app.local_entrypoint()
def test_fixed(dtype: str = "bf16"):
    """Launch the fixed-result test remotely and exit with 1 if it fails.

    Args:
        dtype: Forwarded unchanged to ``run_fixed_result_test``.
    """
    print(f"Launching Modal test with dtype: {dtype}")
    outcome = run_fixed_result_test.remote(dtype)

    # Print the comparison report
    report_text = "\n" + outcome["report"]
    print(report_text)

    if outcome["passed"]:
        return
    sys.exit(1)