forked from flagos-ai/FlagScale
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun.py
More file actions
145 lines (113 loc) · 4.94 KB
/
run.py
File metadata and controls
145 lines (113 loc) · 4.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
import warnings
import hydra
from omegaconf import DictConfig, OmegaConf
from flagscale.logger import logger
from flagscale.runner.autotuner_factory import AutotunerFactory
from flagscale.runner.runner_base import Runner
from flagscale.runner.runner_inference import SSHInferenceRunner
from flagscale.runner.runner_serve import CloudServeRunner, SSHServeRunner
from flagscale.runner.runner_train import CloudTrainRunner, SSHTrainRunner
from flagscale.runner.utils import is_master
# NOTE: the `before_start` field may switch to the actual runtime environment while the program runs.
# To accommodate that, imports were intended to live inside function bodies rather than at the top of
# the file. NOTE(review): the imports above are currently at module level — confirm this still works
# with `before_start`.
# Enabled unless the FLAGSCALE_USE_V1 environment variable is set to something
# other than "1" or "true" (compared case-insensitively); defaults to enabled.
FLAGSCALE_USE_V1 = os.environ.get("FLAGSCALE_USE_V1", "1").lower() in ("1", "true")

# Allowed actions, keyed by task type.
TASK_ACTIONS = {
    "train": {"run", "dryrun", "test", "stop", "query", "auto_tune"},
    "inference": {"run", "dryrun", "test", "stop"},
    "serve": {"run", "test", "stop", "auto_tune"},
    "compress": {"run", "dryrun", "stop"},
    "rl": {"run", "dryrun", "test", "stop"},
}

# Every recognised task type. Derived from TASK_ACTIONS so the two cannot drift apart.
VALID_TASKS = set(TASK_ACTIONS)

# Legacy (pre-v1) runner class per task type; only train/inference/serve have one.
LEGACY_RUNNER_MAP = {
    "train": SSHTrainRunner,
    "inference": SSHInferenceRunner,
    "serve": SSHServeRunner,
}
def check_and_reset_deploy_config(config: DictConfig) -> None:
    """Move a legacy ``experiment.deploy`` section to ``experiment.runner.deploy``.

    When ``config.experiment.deploy`` is present and non-empty, it is copied to
    ``config.experiment.runner.deploy``. The old key is then deleted and a
    deprecation-style warning is emitted. The config is modified in place;
    nothing happens when the legacy key is absent or empty.

    Args:
        config: The Hydra/OmegaConf experiment configuration.
    """
    # Local import: `open_dict` comes from omegaconf, which this module already depends on.
    from omegaconf import open_dict

    if not config.experiment.get("deploy", {}):
        return

    # open_dict lifts struct mode only for the duration of the block and then
    # restores each node's previous flag (instead of forcing it to True).
    # `experiment` must be opened too: struct-mode DictConfigs reject key
    # deletion, so the `del` below would otherwise raise.
    with open_dict(config.experiment), open_dict(config.experiment.runner):
        config.experiment.runner.deploy = config.experiment.deploy
        del config.experiment.deploy

    warnings.warn(
        "'config.experiment.deploy' has been moved to 'config.experiment.runner.deploy'. "
        "Support for the old location will be removed in a future release."
    )
def validate_task(task_type: str, action: str) -> None:
    """Check that ``task_type`` is known and that ``action`` is permitted for it.

    Raises:
        ValueError: if ``task_type`` is not in VALID_TASKS, or if ``action`` is
            not among the actions listed for it in TASK_ACTIONS.
    """
    if task_type not in VALID_TASKS:
        known = sorted(VALID_TASKS)
        raise ValueError(f"Invalid task_type '{task_type}', must be one of {known}")

    permitted = TASK_ACTIONS[task_type]
    if action in permitted:
        return

    raise ValueError(
        f"Action '{action}' is not allowed for task_type '{task_type}'. "
        f"Allowed actions: {sorted(permitted)}"
    )
def get_runner(config: DictConfig, task_type: str):
    """Create the runner that will execute ``task_type``.

    ``config.experiment.runner.type`` (default ``"ssh"``) selects the backend:

    * ``"cloud"``:
      * train → ``CloudTrainRunner``.
      * serve → the v1 ``Runner`` when FLAGSCALE_USE_V1 is enabled, otherwise
        ``CloudServeRunner``.
      * any other task type is rejected.
    * anything else → the v1 ``Runner`` when FLAGSCALE_USE_V1 is enabled.
      Otherwise the legacy runner from LEGACY_RUNNER_MAP is used, after a warning.

    Args:
        config: Full experiment configuration; passed through to the runner.
        task_type: Task type, e.g. "train" or "serve".

    Returns:
        The constructed runner instance.

    Raises:
        NotImplementedError: if the selected backend has no runner for ``task_type``.
    """
    runner_type = config.experiment.runner.get("type", "ssh")

    if runner_type == "cloud":
        if task_type == "train":
            return CloudTrainRunner(config)
        if task_type == "serve":
            return Runner(config) if FLAGSCALE_USE_V1 else CloudServeRunner(config)
        raise NotImplementedError(f"Task type '{task_type}' is not supported by cloud runner")

    if FLAGSCALE_USE_V1:
        return Runner(config)

    logger.warning(
        "Using legacy runner, which will be removed in future. Please use new runner instead."
    )
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
    # which would silently turn this check into a bare KeyError.
    runner_cls = LEGACY_RUNNER_MAP.get(task_type)
    if runner_cls is None:
        raise NotImplementedError(f"Task type '{task_type}' is not supported by legacy runner")
    return runner_cls(config)
def handle_auto_tune(config: DictConfig, task_type: str) -> None:
    """Run the auto-tuner for ``task_type``; only "serve" and "train" are supported.

    For "train", the tuner runs only where ``is_master(config)`` is true; every
    other process returns without doing anything.

    Raises:
        NotImplementedError: for any task type other than "serve" or "train".
    """
    if task_type not in {"serve", "train"}:
        raise NotImplementedError(f"Auto tune is not implemented for task type '{task_type}'")

    # MPI-based (train) runs need just one autotuner process.
    skip_non_master = task_type == "train" and not is_master(config)
    if skip_non_master:
        return

    autotuner_cls = AutotunerFactory.get_autotuner(task_type)
    tuner = autotuner_cls(config)
    tuner.tune()
def execute_action(runner, action: str, task_type: str, config: DictConfig) -> None:
    """Call the runner method that corresponds to ``action``.

    Args:
        runner: Runner instance; must provide the ``run``/``stop``/``query``
            methods used below.
        action: One of "run", "dryrun", "test", "stop" or "query".
        task_type: Task type. For "train" with action "run", the
            ``enable_monitoring`` and ``enable_gpu_health_check`` flags from
            ``config.experiment.runner`` (both default False) are forwarded to
            ``runner.run``.
        config: Full experiment configuration.

    Raises:
        ValueError: if ``action`` is not one of the handled actions.
    """
    if action == "run" and task_type == "train":
        runner_cfg = config.experiment.runner
        monitoring = runner_cfg.get("enable_monitoring", False)
        gpu_check = runner_cfg.get("enable_gpu_health_check", False)
        runner.run(enable_monitoring=monitoring, enable_gpu_health_check=gpu_check)
        if monitoring:
            logger.info("Monitor service will be started automatically when training begins.")
        return

    # Lambdas defer attribute lookup, so only the selected runner method is touched.
    handlers = {
        "run": lambda: runner.run(),
        "dryrun": lambda: runner.run(dryrun=True),
        "test": lambda: runner.run(with_test=True),
        "stop": lambda: runner.stop(),
        "query": lambda: runner.query(),
    }
    handler = handlers.get(action)
    if handler is None:
        raise ValueError(f"Unknown action '{action}'")
    handler()
@hydra.main(version_base=None, config_name="config")
def main(config: DictConfig) -> None:
    """CLI entry point: validate the requested task/action, then dispatch it.

    Steps:

    1. Any legacy ``experiment.deploy`` section is migrated first.
    2. ``config.experiment.task.type`` selects the task and ``config.action``
       the action.
    3. Dispatch:
       * ``auto_tune`` goes to the auto-tuner, which invokes the runner
         internally.
       * Every other action is executed on the runner built by ``get_runner``.
    """
    check_and_reset_deploy_config(config)

    task_type = config.experiment.task.get("type", None)
    action = config.action
    validate_task(task_type, action)

    if action != "auto_tune":
        selected_runner = get_runner(config, task_type)
        execute_action(selected_runner, action, task_type, config)
    else:
        handle_auto_tune(config, task_type)
# Script entry point. `main` is decorated with `@hydra.main`, which builds the
# `config` argument, so it is called here without arguments.
if __name__ == "__main__":
    main()