Skip to content

Commit 651ec9e

Browse files
committed
feat: Return structured training task results
Store and return structured results from training tasks, including run info, timing, metrics, parameters, tags, artifacts, and any errors. Previously, only the internal MLflow URL was returned (using the Docker service name for the host), which limited usability both for programmatic access and for users attempting to reach the MLflow UI from outside the Docker network. Signed-off-by: Phoevos Kalemkeris <phoevos.kalemkeris@ucl.ac.uk>
1 parent 6f0eaa7 commit 651ec9e

4 files changed

Lines changed: 132 additions & 31 deletions

File tree

cogstack_model_gateway/common/tracking.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import os
3+
from datetime import UTC, datetime
34

45
import mlflow
56
import mlflow.models
@@ -74,6 +75,54 @@ def get_exceptions(self):
7475
else [value for key, value in self.data.tags.items() if key.startswith("exception_")]
7576
)
7677

78+
def to_dict(self) -> dict:
79+
"""Convert the tracking task to a structured results dictionary.
80+
81+
Returns a comprehensive dictionary containing run metadata, timing information,
82+
training metrics, parameters, tags, and artifact locations.
83+
84+
Returns:
85+
dict: Structured results with the following top-level keys:
86+
- run: Run identification and status
87+
- timing: Start/end times and duration
88+
- metrics: Training metrics (e.g. accuracy, loss)
89+
- params: Training parameters (e.g. learning_rate, epochs)
90+
- tags: Custom tags and metadata
91+
- artifacts: Artifact URIs
92+
- error: Logged exceptions, if any
93+
"""
94+
start_time_ms, end_time_ms = self.info.start_time, self.info.end_time
95+
96+
started_at = datetime.fromtimestamp(start_time_ms / 1000, tz=UTC).isoformat()
97+
finished_at, duration_seconds = None, None
98+
99+
if end_time_ms:
100+
finished_at = datetime.fromtimestamp(end_time_ms / 1000, tz=UTC).isoformat()
101+
duration_seconds = (end_time_ms - start_time_ms) / 1000
102+
103+
return {
104+
"run": {
105+
"run_id": self.info.run_id,
106+
"run_name": self.info.run_name,
107+
"experiment_id": self.info.experiment_id,
108+
"status": self.status,
109+
"lifecycle_stage": self.info.lifecycle_stage,
110+
"internal_url": self.url,
111+
},
112+
"timing": {
113+
"started_at": started_at,
114+
"finished_at": finished_at,
115+
"duration_seconds": duration_seconds,
116+
},
117+
"metrics": self.data.metrics or {},
118+
"params": self.data.params or {},
119+
"tags": self.data.tags or {},
120+
"artifacts": {
121+
"artifact_uri": self.info.artifact_uri,
122+
},
123+
"error": exceptions if (exceptions := self.get_exceptions()) else None,
124+
}
125+
77126

78127
class TrackingClient:
79128
def __init__(

cogstack_model_gateway/scheduler/scheduler.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import logging
23
import time
34

@@ -140,11 +141,11 @@ def poll_task_status(self, task_uuid: str, tracking_id: str = None) -> dict:
140141
task = self.tracking_client.get_task(tracking_id)
141142
if task is None:
142143
raise ValueError(f"Task '{task_uuid}' not found in tracking server")
143-
res = {"url": task.url, "error": task.get_exceptions()}
144+
144145
if task.is_finished:
145-
return {"status": Status.SUCCEEDED, **res}
146+
return {"status": Status.SUCCEEDED, "results": task.to_dict()}
146147
elif task.is_failed or task.is_killed:
147-
return {"status": Status.FAILED, **res}
148+
return {"status": Status.FAILED, "results": task.to_dict()}
148149
else:
149150
# Task is scheduled or still running
150151
time.sleep(5)
@@ -266,18 +267,19 @@ def _handle_task_success(self, task_uuid: str, response: Response, ack: callable
266267

267268
results = self.poll_task_status(task_uuid, tracking_id)
268269
if results["status"] == Status.FAILED:
269-
log.error(f"Task '{task_uuid}' failed: {results['error']}")
270+
log.error(f"Task '{task_uuid}' failed: {results['results']['error']}")
270271
task = self.task_manager.update_task(
271-
task_uuid, status=Status.FAILED, error_message=str(results["error"])
272+
task_uuid, status=Status.FAILED, error_message=str(results["results"]["error"])
272273
)
273274
tasks_completed_total.labels(
274275
**get_task_labels(task), status=task.status.value
275276
).inc()
276277
return task
277278
else:
278279
log.info(f"Task '{task_uuid}' completed, writing results to object store")
280+
results_json = json.dumps(results["results"], indent=2)
279281
object_key = self.results_object_store_manager.upload_object(
280-
results["url"].encode(), "results.url", prefix=task_uuid
282+
results_json.encode(), "results.json", prefix=task_uuid
281283
)
282284
task = self.task_manager.update_task(
283285
task_uuid, status=Status.SUCCEEDED, result=object_key

tests/integration/test_api.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
count_deployed_model_containers,
2424
download_result_object,
2525
get_deployed_model_container,
26-
parse_mlflow_url,
2726
setup_cms,
2827
setup_scheduler,
2928
setup_testcontainers,
@@ -34,6 +33,7 @@
3433
verify_results_match_api_info,
3534
verify_task_payload_in_object_store,
3635
verify_task_submitted_successfully,
36+
verify_training_task_results,
3737
wait_for_task_completion,
3838
)
3939

@@ -116,8 +116,11 @@ def trained_model(client: TestClient, config: Config) -> tuple[str, str]:
116116
tm: TaskManager = config.task_manager
117117
task = wait_for_task_completion(response_json["uuid"], tm, expected_status=Status.SUCCEEDED)
118118

119-
_, parsed = download_result_object(task.result, config.results_object_store_manager, "text")
120-
_, _, run_id = parse_mlflow_url(parsed)
119+
_, parsed = download_result_object(task.result, config.results_object_store_manager)
120+
121+
verify_training_task_results(parsed, task)
122+
123+
run_id = parsed["run"]["run_id"]
121124

122125
tc: TrackingClient = config.tracking_client
123126
model_uri = tc.get_model_uri(run_id)
@@ -570,11 +573,9 @@ def test_train_supervised(client: TestClient, config: Config):
570573

571574
verify_queue_is_empty(config.queue_manager)
572575

573-
res, parsed = download_result_object(task.result, config.results_object_store_manager, "text")
574-
575-
_, _, run_id = parse_mlflow_url(parsed)
576-
assert run_id == task.tracking_id
576+
res, parsed = download_result_object(task.result, config.results_object_store_manager)
577577

578+
verify_training_task_results(parsed, task)
578579
verify_results_match_api_info(client, task, res)
579580

580581

@@ -599,11 +600,9 @@ def test_train_unsupervised(client: TestClient, config: Config):
599600

600601
verify_queue_is_empty(config.queue_manager)
601602

602-
res, parsed = download_result_object(task.result, config.results_object_store_manager, "text")
603-
604-
_, _, run_id = parse_mlflow_url(parsed)
605-
assert run_id == task.tracking_id
603+
res, parsed = download_result_object(task.result, config.results_object_store_manager)
606604

605+
verify_training_task_results(parsed, task)
607606
verify_results_match_api_info(client, task, res)
608607

609608

@@ -622,11 +621,9 @@ def test_train_unsupervised_with_hf_hub_dataset(client: TestClient, config: Conf
622621

623622
verify_queue_is_empty(config.queue_manager)
624623

625-
res, parsed = download_result_object(task.result, config.results_object_store_manager, "text")
626-
627-
_, _, run_id = parse_mlflow_url(parsed)
628-
assert run_id == task.tracking_id
624+
res, parsed = download_result_object(task.result, config.results_object_store_manager)
629625

626+
verify_training_task_results(parsed, task)
630627
verify_results_match_api_info(client, task, res)
631628

632629

@@ -651,11 +648,9 @@ def test_train_metacat(client: TestClient, config: Config):
651648

652649
verify_queue_is_empty(config.queue_manager)
653650

654-
res, parsed = download_result_object(task.result, config.results_object_store_manager, "text")
655-
656-
_, _, run_id = parse_mlflow_url(parsed)
657-
assert run_id == task.tracking_id
651+
res, parsed = download_result_object(task.result, config.results_object_store_manager)
658652

653+
verify_training_task_results(parsed, task)
659654
verify_results_match_api_info(client, task, res)
660655

661656

@@ -679,11 +674,9 @@ def test_evaluate(client: TestClient, config: Config):
679674

680675
verify_queue_is_empty(config.queue_manager)
681676

682-
res, parsed = download_result_object(task.result, config.results_object_store_manager, "text")
683-
684-
_, _, run_id = parse_mlflow_url(parsed)
685-
assert run_id == task.tracking_id
677+
res, parsed = download_result_object(task.result, config.results_object_store_manager)
686678

679+
verify_training_task_results(parsed, task)
687680
verify_results_match_api_info(client, task, res)
688681

689682

tests/integration/utils.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from docker.models.containers import Container
1414
from fastapi.testclient import TestClient
1515
from git import Repo
16+
from mlflow.entities import LifecycleStage, RunStatus
1617
from testcontainers.compose import DockerCompose
1718
from testcontainers.core.container import DockerClient, DockerContainer
1819
from testcontainers.minio import MinioContainer
@@ -108,8 +109,6 @@ def setup_scheduler(request: pytest.FixtureRequest):
108109

109110

110111
def setup_cms(request: pytest.FixtureRequest, cleanup_cms: bool) -> dict[str, dict]:
111-
cleanup_deployed_model_containers()
112-
113112
try:
114113
clone_cogstack_model_serve()
115114
except Exception as e:
@@ -440,6 +439,64 @@ def verify_results_match_api_info(client: TestClient, task: Task, result: bytes)
440439
assert download_results.content == result
441440

442441

442+
def verify_training_task_results(parsed_results: dict, task: Task):
    """Assert that structured training results are well-formed and match *task*.

    Checks that every expected top-level section (run, timing, metrics, params,
    tags, artifacts, error) is present and that the key fields agree with the
    task metadata recorded by the tracking server.

    Args:
        parsed_results: The parsed JSON results from the training task
        task: The Task object to verify against
    """
    assert isinstance(parsed_results, dict)
    for section in ("run", "timing", "metrics", "params", "tags", "artifacts", "error"):
        assert section in parsed_results

    run = parsed_results["run"]
    assert isinstance(run, dict)
    for field in (
        "run_id",
        "run_name",
        "experiment_id",
        "status",
        "lifecycle_stage",
        "internal_url",
    ):
        assert field in run
    assert run["run_id"] == task.tracking_id
    assert run["status"] == RunStatus.to_string(RunStatus.FINISHED)
    assert run["lifecycle_stage"] == LifecycleStage.ACTIVE
    assert run["run_name"] is not None
    assert run["experiment_id"] is not None
    assert run["internal_url"] is not None

    # The internal URL should round-trip to the same experiment and run IDs.
    _, url_experiment_id, url_run_id = parse_mlflow_url(run["internal_url"])
    assert run["experiment_id"] == url_experiment_id
    assert run["run_id"] == url_run_id

    timing = parsed_results["timing"]
    assert isinstance(timing, dict)
    for field in ("started_at", "finished_at", "duration_seconds"):
        assert field in timing
        assert timing[field] is not None

    assert isinstance(parsed_results["metrics"], dict)
    assert isinstance(parsed_results["params"], dict)
    assert isinstance(parsed_results["tags"], dict)

    artifacts = parsed_results["artifacts"]
    assert isinstance(artifacts, dict)
    assert artifacts.get("artifact_uri") is not None

    # A successfully finished training run should carry no logged exceptions.
    assert parsed_results["error"] is None
443500
def parse_mlflow_url(url: str) -> tuple:
444501
response = requests.get(url)
445502
assert response.status_code == 200

0 commit comments

Comments
 (0)