Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ mount_as = "/opt/Megatron-Bridge"
[cmd_args]
gpu_type = "b200"
container_image = "nvcr.io#nvidia/nemo:25.11.01"
model_name = "qwen3"
model_size = "30b_a3b"
gpus_per_node = 8
model_family_name = "qwen3"
model_recipe_name = "30b_a3b"
gpus_per_node = 4
num_gpus = 8
domain = "llm"
task = "pretrain"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ mount_as = "/opt/Megatron-Bridge"
[cmd_args]
gpu_type = "gb200"
container_image = "nvcr.io#nvidia/nemo:25.11.01"
model_name = "qwen3"
model_size = "30b_a3b"
model_family_name = "qwen3"
model_recipe_name = "30b_a3b"
gpus_per_node = 4
num_gpus = 8
domain = "llm"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ mount_as = "/opt/Megatron-Bridge"
[cmd_args]
gpu_type = "gb300"
container_image = "nvcr.io#nvidia/nemo:25.11.01"
model_name = "qwen3"
model_size = "30b_a3b"
model_family_name = "qwen3"
model_recipe_name = "30b_a3b"
gpus_per_node = 4
num_gpus = 8
domain = "llm"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ mount_as = "/opt/Megatron-Bridge"
[cmd_args]
gpu_type = "h100"
container_image = "nvcr.io#nvidia/nemo:25.11.01"
model_name = "qwen3"
model_size = "30b_a3b"
model_family_name = "qwen3"
model_recipe_name = "30b_a3b"
gpus_per_node = 8
num_gpus = 16
domain = "llm"
Expand Down
10 changes: 5 additions & 5 deletions doc/workloads/megatron_bridge.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ Test TOML example:
[cmd_args]
# Container can be an NGC/enroot URL (nvcr.io#...) or a local .sqsh path.
container_image = "nvcr.io#nvidia/nemo:25.11.01"

model_name = "qwen3"
model_size = "30b_a3b"
model_family_name = "qwen3"
model_recipe_name = "30b_a3b"
task = "pretrain"
domain = "llm"
compute_dtype = "fp8_mx"
Expand Down Expand Up @@ -55,8 +55,8 @@ Test-in-Scenario example:

[Tests.cmd_args]
container_image = "nvcr.io#nvidia/nemo:25.11.01"
model_name = "qwen3"
model_size = "30b_a3b"
model_family_name = "qwen3"
model_recipe_name = "30b_a3b"
task = "pretrain"
domain = "llm"
compute_dtype = "fp8_mx"
Expand Down
87 changes: 82 additions & 5 deletions src/cloudai/cli/handlers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -21,11 +21,12 @@
import signal
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, List, Optional
from typing import Any, Callable, List, Optional
from unittest.mock import Mock

import toml
import yaml
from pydantic import ValidationError

from cloudai.core import (
BaseInstaller,
Expand All @@ -41,10 +42,10 @@
TestScenario,
)
from cloudai.models.scenario import ReportConfig
from cloudai.models.workload import TestDefinition
from cloudai.models.workload import TestDefinition, TestRun
from cloudai.parser import HOOK_ROOT
from cloudai.systems.slurm import SingleSbatchRunner, SlurmSystem
from cloudai.util import prepare_output_dir
from cloudai.util import flatten_dict, prepare_output_dir


def _log_installation_dirs(prefix: str, system: System) -> None:
Expand Down Expand Up @@ -145,7 +146,23 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int:
continue

env = CloudAIGymEnv(test_run=test_run, runner=runner.runner)
agent = agent_class(env)

try:
agent_config = test_run.test.agent_config
agent_overrides = (
validate_agent_overrides(test_run, agent_type, agent_config) if agent_config is not None else None
)
except ValidationError as e:
logging.error(f"Invalid agent_config for agent '{agent_type}': ")
for error in e.errors():
logging.error(f" - {'.'.join(str(var_name) for var_name in error['loc'])}: {error['msg']}")
logging.error("Valid overrides: ")
for item, desc in validate_agent_overrides(test_run, agent_type).items():
logging.error(f" - {item}: {desc}")
err = 1
continue
agent = agent_class(env, **agent_overrides) if agent_overrides is not None else agent_class(env)

for step in range(agent.max_steps):
result = agent.select_action()
if result is None:
Expand All @@ -166,6 +183,66 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int:
return err


def validate_agent_overrides(
    test_run: TestRun, agent_type: str, agent_config: Optional[dict[str, Any]] = None
) -> dict[str, Any]:
    """
    Check agent configuration overrides and convert them into agent keyword arguments.

    When ``agent_config`` is non-empty, it is validated against the config model registered for
    ``agent_type`` (a ``seed_parameters`` entry, if present, is first checked against the test run)
    and the validated values that are not ``None`` are returned.

    When ``agent_config`` is empty or omitted, the result instead maps every field name of that
    config model to its description, i.e. the list of supported overrides.

    Raises:
        ValueError: if no registered agent named ``agent_type`` provides a config model.

    Errors raised by ``validate_seed_parameters`` or by the model validation itself are not caught
    here and propagate to the caller.
    """
    # Only agents that declare a (truthy) config model can accept overrides.
    overridable = {name: cls.config for name, cls in Registry().agents_map.items() if cls.config}

    config_class = overridable.get(agent_type)
    if not config_class:
        valid_types = ", ".join(f"'{name}'" for name in overridable)
        raise ValueError(
            f"Agent type '{agent_type}' does not support configuration overrides. "
            f"Valid agent types are: {valid_types}. "
        )

    # No overrides supplied: describe what could be overridden instead.
    if not agent_config:
        return {name: info.description for name, info in config_class.model_fields.items()}

    seed_parameters = agent_config.get("seed_parameters")
    if seed_parameters:
        validate_seed_parameters(test_run, seed_parameters)

    agent_kwargs = config_class.model_validate(agent_config).model_dump(exclude_none=True)
    logging.debug(f"Applying agent config overrides for '{agent_type}': {agent_kwargs}")
    return agent_kwargs


def validate_seed_parameters(test_run: TestRun, seed_parameters: dict[str, Any]) -> None:
    """
    Ensure every seed parameter names a DSE-able command-line argument and one of its listed values.

    The test's ``cmd_args`` are dumped (dropping ``None`` values) and flattened; only entries whose
    value is a list are treated as DSE-able.

    Raises:
        KeyError: if a seed parameter key is not among the DSE-able arguments.
        ValueError: if a seed parameter value is not in the list given for that argument.
    """
    dumped_cmd_args = test_run.test.cmd_args.model_dump(exclude_none=True)
    dse_cmd_args = {
        arg_name: candidates
        for arg_name, candidates in flatten_dict(dumped_cmd_args).items()
        if isinstance(candidates, list)
    }

    logging.debug("Validating seed parameters against DSE-able command-line arguments:")
    logging.debug(f"\t{dse_cmd_args}")

    for seed_key, seed_value in seed_parameters.items():
        if seed_key not in dse_cmd_args:
            raise KeyError(
                f"Seed parameter '{seed_key}' not found in DSE-able command-line arguments. "
                f"Ensure that the key is one of the following available keys: {list(dse_cmd_args)}"
            )

        allowed_values = dse_cmd_args[seed_key]
        if seed_value not in allowed_values:
            raise ValueError(
                f"Seed parameter '{seed_key}' value '{seed_value}' not found in DSE-able command-line arguments. "
                f"Ensure that the value is one of the following available values: {allowed_values}"
            )

    logging.debug("Seed parameters validated successfully.")


def generate_reports(system: System, test_scenario: TestScenario, result_dir: Path) -> None:
registry = Registry()

Expand Down
8 changes: 6 additions & 2 deletions src/cloudai/configurator/base_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -15,7 +15,9 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple
from typing import Any, Dict, Optional, Tuple

from cloudai.models.agent_config import AgentConfig

from .base_gym import BaseGym

Expand All @@ -28,6 +30,8 @@ class BaseAgent(ABC):
Automatically infers parameter types from TestRun's cmd_args.
"""

config: Optional[AgentConfig] = None

def __init__(self, env: BaseGym):
"""
Initialize the agent with the environment.
Expand Down
28 changes: 28 additions & 0 deletions src/cloudai/models/agent_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, Field


class AgentConfig(BaseModel, ABC):
    """
    Common base model for agent configuration overrides.

    Keys that are not declared as fields are rejected (``extra="forbid"``). Both declared fields
    are optional and default to ``None``.
    """

    model_config = ConfigDict(extra="forbid")

    random_seed: Optional[int] = Field(
        default=None,
        description="Random seed for reproducibility",
    )
    seed_parameters: Optional[dict[str, Any]] = Field(
        default=None,
        description="Seed parameters for reproducibility",
    )
3 changes: 2 additions & 1 deletion src/cloudai/models/workload.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -107,6 +107,7 @@ class TestDefinition(BaseModel, ABC):
agent_steps: int = 1
agent_metrics: list[str] = Field(default=["default"])
agent_reward_function: str = "inverse"
agent_config: Optional[dict[str, Any]] = None

@property
def cmd_args_dict(self) -> Dict[str, Union[str, List[str]]]:
Expand Down
Loading
Loading