From 1012aa0e83a1de7966cdddb77773d8f5486b2636 Mon Sep 17 00:00:00 2001 From: Alex Manley Date: Tue, 27 Jan 2026 13:46:40 -0800 Subject: [PATCH 01/22] agent hyperparam interface --- src/cloudai/cli/handlers.py | 51 +++++++++++++++++-- src/cloudai/models/agent_config.py | 79 ++++++++++++++++++++++++++++++ src/cloudai/models/workload.py | 3 +- 3 files changed, 129 insertions(+), 4 deletions(-) create mode 100644 src/cloudai/models/agent_config.py diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index d474ff421..df381c791 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -1,5 +1,5 @@ # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES -# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,11 +21,12 @@ import signal from contextlib import contextmanager from pathlib import Path -from typing import Callable, List, Optional +from typing import Any, Callable, List, Optional from unittest.mock import Mock import toml import yaml +from pydantic import ValidationError from cloudai.core import ( BaseInstaller, @@ -40,6 +41,11 @@ TestParser, TestScenario, ) +from cloudai.models.agent_config import ( + BayesianOptimizationConfig, + GeneticAlgorithmConfig, + MultiArmedBanditConfig, +) from cloudai.models.scenario import ReportConfig from cloudai.models.workload import TestDefinition from cloudai.parser import HOOK_ROOT @@ -145,7 +151,19 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: continue env = CloudAIGymEnv(test_run=test_run, runner=runner.runner) - agent = agent_class(env) + + try: + agent_overrides = validate_agent_overrides(agent_type, test_run.test.agent_config) + except ValidationError as e: + logging.error(f"Invalid agent_config for agent '{agent_type}':") + for error in e.errors(): + field = ".".join(str(loc) for loc in error["loc"]) + logging.error(f" - {field}: {error['msg']}") + err = 1 + continue + + agent = agent_class(env, **agent_overrides) + for step in range(agent.max_steps): result = agent.select_action() if result is None: @@ -166,6 +184,33 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: return err +def validate_agent_overrides(agent_type: str, agent_config: Optional[dict[str, Any]]) -> dict[str, Any]: + """ + Validate and process agent configuration overrides. + """ + if not agent_config: + return {} + + config_class_map = { + "ga": GeneticAlgorithmConfig, + "bo": BayesianOptimizationConfig, + "mab": MultiArmedBanditConfig, + } + + config_class = config_class_map.get(agent_type) + if not config_class: + logging.debug(f"No config validation available for agent type '{agent_type}', using defaults.") + return {} + + validated_config = config_class.model_validate(agent_config) + agent_kwargs = validated_config.model_dump(exclude_none=True) + + if agent_kwargs: + logging.info(f"Applying agent config overrides for '{agent_type}': {agent_kwargs}") + + return agent_kwargs + + def generate_reports(system: System, test_scenario: TestScenario, result_dir: Path) -> None: registry = Registry() diff --git a/src/cloudai/models/agent_config.py b/src/cloudai/models/agent_config.py new file mode 100644 index 000000000..6688baaa2 --- /dev/null +++ b/src/cloudai/models/agent_config.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES +# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC +from typing import Any, Optional + +from pydantic import BaseModel, ConfigDict, Field + + +class AgentConfig(BaseModel, ABC): + """ + Base configuration for agent overrides. + """ + + model_config = ConfigDict(extra="forbid") + random_seed: Optional[int] = Field(default=None, description="Random seed for reproducibility") + +class GeneticAlgorithmConfig(AgentConfig): + """ + Configuration overrides for Genetic Algorithm agent. + """ + + population_size: Optional[int] = Field(default=None, ge=2, description="Population size for the genetic algorithm") + n_offsprings: Optional[int] = Field(default=None, ge=1, description="Number of offsprings per generation") + crossover_prob: Optional[float] = Field(default=None, ge=0.0, le=1.0, description="Crossover probability") + mutation_prob: Optional[float] = Field(default=None, ge=0.0, le=1.0, description="Mutation probability") + + +class BayesianOptimizationConfig(AgentConfig): + """ + Configuration overrides for Bayesian Optimization agent. + """ + + sobol_num_trials: Optional[int] = Field(default=None, ge=1, description="Number of SOBOL initialization trials") + botorch_num_trials: Optional[int] = Field( + default=None, description="Number of BoTorch trials (-1 for unlimited until max_steps)" + ) + +class MultiArmedBanditConfig(AgentConfig): + """ + Configuration overrides for Multi-Armed Bandit agent. + """ + + algorithm: Optional[str] = Field( + default=None, + description="MAB algorithm: ucb1, ts (thompson_sampling), epsilon_greedy, softmax, or random", + ) + algorithm_params: Optional[dict[str, Any]] = Field( + default=None, description="Algorithm-specific parameters (e.g., alpha for UCB1, epsilon for epsilon_greedy)" + ) + seed_parameters: Optional[dict[str, Any]] = Field( + default=None, description="Initial seed configuration to evaluate first" + ) + max_arms: Optional[int] = Field(default=None, ge=1, description="Maximum number of arms in the action space") + warm_start_size: Optional[int] = Field( + default=None, ge=0, description="Number of arms to randomly explore initially" + ) + epsilon_override: Optional[float] = Field( + default=None, ge=0.0, le=1.0, description="Epsilon value for exploration (overrides algorithm epsilon)" + ) + max_explore_steps: Optional[int] = Field( + default=None, ge=0, description="Maximum steps for epsilon exploration (None for unlimited)" + ) + prefer_unseen_random: Optional[bool] = Field( + default=None, description="Prefer unseen arms during random exploration (epsilon)" + ) diff --git a/src/cloudai/models/workload.py b/src/cloudai/models/workload.py index 1745ae734..0a962cf59 100644 --- a/src/cloudai/models/workload.py +++ b/src/cloudai/models/workload.py @@ -1,5 +1,5 @@ # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES -# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -107,6 +107,7 @@ class TestDefinition(BaseModel, ABC): agent_steps: int = 1 agent_metrics: list[str] = Field(default=["default"]) agent_reward_function: str = "inverse" + agent_config: Optional[dict[str, Any]] = None @property def cmd_args_dict(self) -> Dict[str, Union[str, List[str]]]: From bd551e49cf05d2ce4cb9ae712c45e28bea1e4abf Mon Sep 17 00:00:00 2001 From: Alex Manley Date: Tue, 27 Jan 2026 13:50:13 -0800 Subject: [PATCH 02/22] fix formatting --- src/cloudai/cli/handlers.py | 14 ++++++-------- src/cloudai/models/agent_config.py | 18 ++++++------------ 2 files changed, 12 insertions(+), 20 deletions(-) diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index df381c791..674013f5e 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -151,7 +151,7 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: continue env = CloudAIGymEnv(test_run=test_run, runner=runner.runner) - + try: agent_overrides = validate_agent_overrides(agent_type, test_run.test.agent_config) except ValidationError as e: @@ -161,7 +161,7 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: logging.error(f" - {field}: {error['msg']}") err = 1 continue - + agent = agent_class(env, **agent_overrides) for step in range(agent.max_steps): @@ -185,9 +185,7 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: def validate_agent_overrides(agent_type: str, agent_config: Optional[dict[str, Any]]) -> dict[str, Any]: - """ - Validate and process agent configuration overrides. - """ + """Validate and process agent configuration overrides.""" if not agent_config: return {} @@ -196,7 +194,7 @@ def validate_agent_overrides(agent_type: str, agent_config: Optional[dict[str, A "bo": BayesianOptimizationConfig, "mab": MultiArmedBanditConfig, } - + config_class = config_class_map.get(agent_type) if not config_class: logging.debug(f"No config validation available for agent type '{agent_type}', using defaults.") @@ -204,10 +202,10 @@ def validate_agent_overrides(agent_type: str, agent_config: Optional[dict[str, A validated_config = config_class.model_validate(agent_config) agent_kwargs = validated_config.model_dump(exclude_none=True) - + if agent_kwargs: logging.info(f"Applying agent config overrides for '{agent_type}': {agent_kwargs}") - + return agent_kwargs diff --git a/src/cloudai/models/agent_config.py b/src/cloudai/models/agent_config.py index 6688baaa2..3e090a622 100644 --- a/src/cloudai/models/agent_config.py +++ b/src/cloudai/models/agent_config.py @@ -21,17 +21,14 @@ class AgentConfig(BaseModel, ABC): - """ - Base configuration for agent overrides. - """ + """Base configuration for agent overrides.""" model_config = ConfigDict(extra="forbid") random_seed: Optional[int] = Field(default=None, description="Random seed for reproducibility") + class GeneticAlgorithmConfig(AgentConfig): - """ - Configuration overrides for Genetic Algorithm agent. - """ + """Configuration overrides for Genetic Algorithm agent.""" population_size: Optional[int] = Field(default=None, ge=2, description="Population size for the genetic algorithm") n_offsprings: Optional[int] = Field(default=None, ge=1, description="Number of offsprings per generation") @@ -40,19 +37,16 @@ class GeneticAlgorithmConfig(AgentConfig): class BayesianOptimizationConfig(AgentConfig): - """ - Configuration overrides for Bayesian Optimization agent. - """ + """Configuration overrides for Bayesian Optimization agent.""" sobol_num_trials: Optional[int] = Field(default=None, ge=1, description="Number of SOBOL initialization trials") botorch_num_trials: Optional[int] = Field( default=None, description="Number of BoTorch trials (-1 for unlimited until max_steps)" ) + class MultiArmedBanditConfig(AgentConfig): - """ - Configuration overrides for Multi-Armed Bandit agent. - """ + """Configuration overrides for Multi-Armed Bandit agent.""" algorithm: Optional[str] = Field( default=None, From c1ecf2563120941d83c53b9d96150b211b39cdeb Mon Sep 17 00:00:00 2001 From: Alex Manley Date: Tue, 27 Jan 2026 14:11:34 -0800 Subject: [PATCH 03/22] default pass no kwargs --- src/cloudai/cli/handlers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index 674013f5e..c8ff8a74f 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -162,7 +162,7 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: err = 1 continue - agent = agent_class(env, **agent_overrides) + agent = agent_class(env, **agent_overrides) if agent_overrides else agent_class(env) for step in range(agent.max_steps): result = agent.select_action() From 9e7d6c35f5e59510fcbb335ca4754027fdb7c5ef Mon Sep 17 00:00:00 2001 From: Alex Manley Date: Thu, 29 Jan 2026 13:27:52 -0800 Subject: [PATCH 04/22] update error logging --- src/cloudai/cli/handlers.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index c8ff8a74f..e7f7f205b 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -155,10 +155,12 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: try: agent_overrides = validate_agent_overrides(agent_type, test_run.test.agent_config) except ValidationError as e: - logging.error(f"Invalid agent_config for agent '{agent_type}':") - for error in e.errors(): - field = ".".join(str(loc) for loc in error["loc"]) - logging.error(f" - {field}: {error['msg']}") + items = ", ".join(str(loc) for error in e.errors() for loc in error["loc"]) + logging.error(f"Invalid agent_config for agent '{agent_type}': {items}") + valid_overrides = validate_agent_overrides(agent_type) + logging.error(f"Valid overrides: ") + for item in valid_overrides.items(): + logging.error(f" - {item[0]}: {item[1]}") err = 1 continue @@ -184,21 +186,30 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: return err -def validate_agent_overrides(agent_type: str, agent_config: Optional[dict[str, Any]]) -> dict[str, Any]: - """Validate and process agent configuration overrides.""" - if not agent_config: - return {} - +def validate_agent_overrides(agent_type: str, agent_config: Optional[dict[str, Any]] = None) -> dict[str, Any]: + """ + Validate and process agent configuration overrides. + If agent_config is empty, returns the available configuration fields for the agent type. + """ config_class_map = { "ga": GeneticAlgorithmConfig, - "bo": BayesianOptimizationConfig, + "bo_gp": BayesianOptimizationConfig, "mab": MultiArmedBanditConfig, } config_class = config_class_map.get(agent_type) if not config_class: - logging.debug(f"No config validation available for agent type '{agent_type}', using defaults.") - return {} + valid_types = ", ".join(f"'{t}'" for t in config_class_map.keys()) + raise ValueError( + f"Agent type '{agent_type}' does not support configuration overrides. " + f"Valid agent types are: {valid_types}. " + ) + + if not agent_config: + available_overrides = {} + for field_name, field_info in config_class.model_fields.items(): + available_overrides[field_name] = field_info.description + return available_overrides validated_config = config_class.model_validate(agent_config) agent_kwargs = validated_config.model_dump(exclude_none=True) From e7fee0c10a445310f0eb0c8722d94792aeb4258b Mon Sep 17 00:00:00 2001 From: Alex Manley Date: Thu, 29 Jan 2026 13:28:47 -0800 Subject: [PATCH 05/22] fix docstring --- src/cloudai/cli/handlers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index e7f7f205b..44a506f66 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -158,7 +158,7 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: items = ", ".join(str(loc) for error in e.errors() for loc in error["loc"]) logging.error(f"Invalid agent_config for agent '{agent_type}': {items}") valid_overrides = validate_agent_overrides(agent_type) - logging.error(f"Valid overrides: ") + logging.error("Valid overrides: ") for item in valid_overrides.items(): logging.error(f" - {item[0]}: {item[1]}") err = 1 @@ -189,6 +189,7 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: def validate_agent_overrides(agent_type: str, agent_config: Optional[dict[str, Any]] = None) -> dict[str, Any]: """ Validate and process agent configuration overrides. + If agent_config is empty, returns the available configuration fields for the agent type. """ config_class_map = { @@ -199,7 +200,7 @@ def validate_agent_overrides(agent_type: str, agent_config: Optional[dict[str, A config_class = config_class_map.get(agent_type) if not config_class: - valid_types = ", ".join(f"'{t}'" for t in config_class_map.keys()) + valid_types = ", ".join(f"'{t}'" for t in config_class_map) raise ValueError( f"Agent type '{agent_type}' does not support configuration overrides. " f"Valid agent types are: {valid_types}. " From cfb38f5c8568ccad9299c0f2421bf28ceb12b448 Mon Sep 17 00:00:00 2001 From: Alex Manley Date: Fri, 30 Jan 2026 16:17:57 -0800 Subject: [PATCH 06/22] make agent interface abstract --- src/cloudai/cli/handlers.py | 45 ++++++++++-------------- src/cloudai/configurator/base_agent.py | 6 +++- src/cloudai/models/agent_config.py | 48 +------------------------- 3 files changed, 24 insertions(+), 75 deletions(-) diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index 44a506f66..475092451 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -41,11 +41,6 @@ TestParser, TestScenario, ) -from cloudai.models.agent_config import ( - BayesianOptimizationConfig, - GeneticAlgorithmConfig, - MultiArmedBanditConfig, -) from cloudai.models.scenario import ReportConfig from cloudai.models.workload import TestDefinition from cloudai.parser import HOOK_ROOT @@ -155,12 +150,12 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: try: agent_overrides = validate_agent_overrides(agent_type, test_run.test.agent_config) except ValidationError as e: - items = ", ".join(str(loc) for error in e.errors() for loc in error["loc"]) - logging.error(f"Invalid agent_config for agent '{agent_type}': {items}") - valid_overrides = validate_agent_overrides(agent_type) + logging.error(f"Invalid agent_config for agent '{agent_type}': ") + for error in e.errors(): + logging.error(f" - {'.'.join(str(var_name) for var_name in error['loc'])}: {error['msg']}") logging.error("Valid overrides: ") - for item in valid_overrides.items(): - logging.error(f" - {item[0]}: {item[1]}") + for item, desc in validate_agent_overrides(agent_type).items(): + logging.error(f" - {item}: {desc}") err = 1 continue @@ -192,32 +187,28 @@ def validate_agent_overrides(agent_type: str, agent_config: Optional[dict[str, A If agent_config is empty, returns the available configuration fields for the agent type. """ - config_class_map = { - "ga": GeneticAlgorithmConfig, - "bo_gp": BayesianOptimizationConfig, - "mab": MultiArmedBanditConfig, - } + registry = Registry() + config_class_map = {} + for agent_name, agent_class in registry.agents_map.items(): + if agent_class.config: + config_class_map[agent_name] = agent_class.config config_class = config_class_map.get(agent_type) if not config_class: - valid_types = ", ".join(f"'{t}'" for t in config_class_map) + valid_types = ", ".join(f"'{agent_name}'" for agent_name in config_class_map) raise ValueError( f"Agent type '{agent_type}' does not support configuration overrides. " f"Valid agent types are: {valid_types}. " ) - if not agent_config: - available_overrides = {} - for field_name, field_info in config_class.model_fields.items(): - available_overrides[field_name] = field_info.description - return available_overrides - - validated_config = config_class.model_validate(agent_config) - agent_kwargs = validated_config.model_dump(exclude_none=True) - - if agent_kwargs: + if agent_config: + validated_config = config_class.model_validate(agent_config) + agent_kwargs = validated_config.model_dump(exclude_none=True) logging.info(f"Applying agent config overrides for '{agent_type}': {agent_kwargs}") - + else: + agent_kwargs = {} + for field_name, field_info in config_class.model_fields.items(): + agent_kwargs[field_name] = field_info.description return agent_kwargs diff --git a/src/cloudai/configurator/base_agent.py b/src/cloudai/configurator/base_agent.py index dbd397099..4b806a53c 100644 --- a/src/cloudai/configurator/base_agent.py +++ b/src/cloudai/configurator/base_agent.py @@ -15,7 +15,9 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Dict, Tuple +from typing import Any, Dict, Optional, Tuple + +from cloudai.models.agent_config import AgentConfig from .base_gym import BaseGym @@ -28,6 +30,8 @@ class BaseAgent(ABC): Automatically infers parameter types from TestRun's cmd_args. """ + config: Optional[AgentConfig] = None + def __init__(self, env: BaseGym): """ Initialize the agent with the environment. diff --git a/src/cloudai/models/agent_config.py b/src/cloudai/models/agent_config.py index 3e090a622..0b04059aa 100644 --- a/src/cloudai/models/agent_config.py +++ b/src/cloudai/models/agent_config.py @@ -15,7 +15,7 @@ # limitations under the License. from abc import ABC -from typing import Any, Optional +from typing import Optional from pydantic import BaseModel, ConfigDict, Field @@ -25,49 +25,3 @@ class AgentConfig(BaseModel, ABC): model_config = ConfigDict(extra="forbid") random_seed: Optional[int] = Field(default=None, description="Random seed for reproducibility") - - -class GeneticAlgorithmConfig(AgentConfig): - """Configuration overrides for Genetic Algorithm agent.""" - - population_size: Optional[int] = Field(default=None, ge=2, description="Population size for the genetic algorithm") - n_offsprings: Optional[int] = Field(default=None, ge=1, description="Number of offsprings per generation") - crossover_prob: Optional[float] = Field(default=None, ge=0.0, le=1.0, description="Crossover probability") - mutation_prob: Optional[float] = Field(default=None, ge=0.0, le=1.0, description="Mutation probability") - - -class BayesianOptimizationConfig(AgentConfig): - """Configuration overrides for Bayesian Optimization agent.""" - - sobol_num_trials: Optional[int] = Field(default=None, ge=1, description="Number of SOBOL initialization trials") - botorch_num_trials: Optional[int] = Field( - default=None, description="Number of BoTorch trials (-1 for unlimited until max_steps)" - ) - - -class MultiArmedBanditConfig(AgentConfig): - """Configuration overrides for Multi-Armed Bandit agent.""" - - algorithm: Optional[str] = Field( - default=None, - description="MAB algorithm: ucb1, ts (thompson_sampling), epsilon_greedy, softmax, or random", - ) - algorithm_params: Optional[dict[str, Any]] = Field( - default=None, description="Algorithm-specific parameters (e.g., alpha for UCB1, epsilon for epsilon_greedy)" - ) - seed_parameters: Optional[dict[str, Any]] = Field( - default=None, description="Initial seed configuration to evaluate first" - ) - max_arms: Optional[int] = Field(default=None, ge=1, description="Maximum number of arms in the action space") - warm_start_size: Optional[int] = Field( - default=None, ge=0, description="Number of arms to randomly explore initially" - ) - epsilon_override: Optional[float] = Field( - default=None, ge=0.0, le=1.0, description="Epsilon value for exploration (overrides algorithm epsilon)" - ) - max_explore_steps: Optional[int] = Field( - default=None, ge=0, description="Maximum steps for epsilon exploration (None for unlimited)" - ) - prefer_unseen_random: Optional[bool] = Field( - default=None, description="Prefer unseen arms during random exploration (epsilon)" - ) From ab3df179b21e3b4140be0092c5facb8006cb8fea Mon Sep 17 00:00:00 2001 From: Alex Manley Date: Tue, 3 Feb 2026 08:38:20 -0800 Subject: [PATCH 07/22] fix copyright --- src/cloudai/configurator/base_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cloudai/configurator/base_agent.py b/src/cloudai/configurator/base_agent.py index 4b806a53c..a82502217 100644 --- a/src/cloudai/configurator/base_agent.py +++ b/src/cloudai/configurator/base_agent.py @@ -1,5 +1,5 @@ # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); From 4d2667f73c8de132311f4710d314567b8b3de986 Mon Sep 17 00:00:00 2001 From: Alex Manley Date: Tue, 3 Feb 2026 09:18:01 -0800 Subject: [PATCH 08/22] better kwargs handling --- src/cloudai/cli/handlers.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index 475092451..b787293d3 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -148,7 +148,11 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: env = CloudAIGymEnv(test_run=test_run, runner=runner.runner) try: - agent_overrides = validate_agent_overrides(agent_type, test_run.test.agent_config) + agent_overrides = ( + validate_agent_overrides(agent_type, test_run.test.agent_config) + if test_run.test.agent_config is not None + else None + ) except ValidationError as e: logging.error(f"Invalid agent_config for agent '{agent_type}': ") for error in e.errors(): @@ -158,8 +162,7 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: logging.error(f" - {item}: {desc}") err = 1 continue - - agent = agent_class(env, **agent_overrides) if agent_overrides else agent_class(env) + agent = agent_class(env, **agent_overrides) if agent_overrides is not None else agent_class(env) for step in range(agent.max_steps): result = agent.select_action() From 09ee673c788ecf5a17b5c6802371e654c330127b Mon Sep 17 00:00:00 2001 From: Alex Manley Date: Wed, 11 Feb 2026 12:35:19 -0800 Subject: [PATCH 09/22] add agent seed parameter support/validation --- src/cloudai/cli/handlers.py | 46 ++++++++++++++++++++++++------ src/cloudai/models/agent_config.py | 3 +- 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index b787293d3..7d440ffbb 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -42,10 +42,10 @@ TestScenario, ) from cloudai.models.scenario import ReportConfig -from cloudai.models.workload import TestDefinition +from cloudai.models.workload import TestDefinition, TestRun from cloudai.parser import HOOK_ROOT from cloudai.systems.slurm import SingleSbatchRunner, SlurmSystem -from cloudai.util import prepare_output_dir +from cloudai.util import flatten_dict, prepare_output_dir def _log_installation_dirs(prefix: str, system: System) -> None: @@ -148,17 +148,16 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: env = CloudAIGymEnv(test_run=test_run, runner=runner.runner) try: + agent_config = test_run.test.agent_config agent_overrides = ( - validate_agent_overrides(agent_type, test_run.test.agent_config) - if test_run.test.agent_config is not None - else None + validate_agent_overrides(test_run, agent_type, agent_config) if agent_config is not None else None ) except ValidationError as e: logging.error(f"Invalid agent_config for agent '{agent_type}': ") for error in e.errors(): logging.error(f" - {'.'.join(str(var_name) for var_name in error['loc'])}: {error['msg']}") logging.error("Valid overrides: ") - for item, desc in validate_agent_overrides(agent_type).items(): + for item, desc in validate_agent_overrides(test_run, agent_type).items(): logging.error(f" - {item}: {desc}") err = 1 continue @@ -184,7 +183,9 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: return err -def validate_agent_overrides(agent_type: str, agent_config: Optional[dict[str, Any]] = None) -> dict[str, Any]: +def validate_agent_overrides( + test_run: TestRun, agent_type: str, agent_config: Optional[dict[str, Any]] = None +) -> dict[str, Any]: """ Validate and process agent configuration overrides. @@ -205,9 +206,14 @@ def validate_agent_overrides(agent_type: str, agent_config: Optional[dict[str, A ) if agent_config: + seed_parameters = agent_config.pop("seed_parameters", None) + if seed_parameters: + valid_seed_parameters = validate_seed_parameters(test_run, seed_parameters) + agent_config["seed_parameters"] = valid_seed_parameters + validated_config = config_class.model_validate(agent_config) agent_kwargs = validated_config.model_dump(exclude_none=True) - logging.info(f"Applying agent config overrides for '{agent_type}': {agent_kwargs}") + logging.debug(f"Applying agent config overrides for '{agent_type}': {agent_kwargs}") else: agent_kwargs = {} for field_name, field_info in config_class.model_fields.items(): @@ -215,6 +221,30 @@ def validate_agent_overrides(agent_type: str, agent_config: Optional[dict[str, A return agent_kwargs +def validate_seed_parameters(test_run: TestRun, seed_parameters: dict[str, Any]) -> dict[str, Any]: + """Validate seed parameters against DSE-able command-line arguments.""" + flat_cmd_args = flatten_dict(test_run.test.cmd_args.model_dump(exclude_none=True)) + dse_cmd_args = {k: v for k, v in flat_cmd_args.items() if isinstance(v, list)} + + logging.debug("Validating seed parameters against DSE-able command-line arguments:") + logging.debug(f"\t{dse_cmd_args}") + + for key, value in seed_parameters.items(): + if key not in dse_cmd_args: + raise KeyError( + f"Seed parameter '{key}' not found in DSE-able command-line arguments. " + f"Ensure that the key is one of the following available keys: {dse_cmd_args.keys()}" + ) + if value not in dse_cmd_args[key]: + raise ValueError( + f"Seed parameter '{key}' value '{value}' not found in DSE-able command-line arguments. " + f"Ensure that the value is one of the following available values: {dse_cmd_args[key]}" + ) + + logging.debug("Seed parameters validated successfully.") + return seed_parameters + + def generate_reports(system: System, test_scenario: TestScenario, result_dir: Path) -> None: registry = Registry() diff --git a/src/cloudai/models/agent_config.py b/src/cloudai/models/agent_config.py index 0b04059aa..072bb928e 100644 --- a/src/cloudai/models/agent_config.py +++ b/src/cloudai/models/agent_config.py @@ -15,7 +15,7 @@ # limitations under the License. from abc import ABC -from typing import Optional +from typing import Any, Optional from pydantic import BaseModel, ConfigDict, Field @@ -25,3 +25,4 @@ class AgentConfig(BaseModel, ABC): model_config = ConfigDict(extra="forbid") random_seed: Optional[int] = Field(default=None, description="Random seed for reproducibility") + seed_parameters: Optional[dict[str, Any]] = Field(default=None, description="Seed parameters for reproducibility") From db8f259242d4f042629fdaa1ee85b51631ed5333 Mon Sep 17 00:00:00 2001 From: Alex Manley Date: Wed, 11 Feb 2026 13:01:52 -0800 Subject: [PATCH 10/22] fix unneeded dict modification and error printing --- src/cloudai/cli/handlers.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index 7d440ffbb..b295059b6 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -206,10 +206,9 @@ def validate_agent_overrides( ) if agent_config: - seed_parameters = agent_config.pop("seed_parameters", None) + seed_parameters = agent_config.get("seed_parameters", None) if seed_parameters: - valid_seed_parameters = validate_seed_parameters(test_run, seed_parameters) - agent_config["seed_parameters"] = valid_seed_parameters + validate_seed_parameters(test_run, seed_parameters) validated_config = config_class.model_validate(agent_config) agent_kwargs = validated_config.model_dump(exclude_none=True) @@ -221,7 +220,7 @@ def validate_agent_overrides( return agent_kwargs -def validate_seed_parameters(test_run: TestRun, seed_parameters: dict[str, Any]) -> dict[str, Any]: +def validate_seed_parameters(test_run: TestRun, seed_parameters: dict[str, Any]) -> None: """Validate seed parameters against DSE-able command-line arguments.""" flat_cmd_args = flatten_dict(test_run.test.cmd_args.model_dump(exclude_none=True)) dse_cmd_args = {k: v for k, v in flat_cmd_args.items() if isinstance(v, list)} @@ -233,7 +232,7 @@ def validate_seed_parameters(test_run: TestRun, seed_parameters: dict[str, Any]) if key not in dse_cmd_args: raise KeyError( f"Seed parameter '{key}' not found in DSE-able command-line arguments. " - f"Ensure that the key is one of the following available keys: {dse_cmd_args.keys()}" + f"Ensure that the key is one of the following available keys: {list(dse_cmd_args.keys())}" ) if value not in dse_cmd_args[key]: raise ValueError( @@ -242,7 +241,6 @@ def validate_seed_parameters(test_run: TestRun, seed_parameters: dict[str, Any]) ) logging.debug("Seed parameters validated successfully.") - return seed_parameters def generate_reports(system: System, test_scenario: TestScenario, result_dir: Path) -> None: From 8583261742295452516b44b20bedebf20eca280e Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Wed, 7 Jan 2026 15:14:45 -0800 Subject: [PATCH 11/22] m-bridge for nemo container 26.02 --- .../test/b200/megatron_bridge_qwen_30b.toml | 6 +- .../test/gb200/megatron_bridge_qwen_30b.toml | 4 +- .../test/gb300/megatron_bridge_qwen_30b.toml | 4 +- .../test/h100/megatron_bridge_qwen_30b.toml | 4 +- doc/workloads/megatron_bridge.rst | 10 +-- .../megatron_bridge/megatron_bridge.py | 51 ++++++++++- .../slurm_command_gen_strategy.py | 82 +++++++++++++---- ...atron_bridge_slurm_command_gen_strategy.py | 88 ++++++++++++++++--- 8 files changed, 204 insertions(+), 45 deletions(-) diff --git a/conf/experimental/megatron_bridge/test/b200/megatron_bridge_qwen_30b.toml b/conf/experimental/megatron_bridge/test/b200/megatron_bridge_qwen_30b.toml index c93eac45a..c67f918c4 100644 --- a/conf/experimental/megatron_bridge/test/b200/megatron_bridge_qwen_30b.toml +++ b/conf/experimental/megatron_bridge/test/b200/megatron_bridge_qwen_30b.toml @@ -28,9 +28,9 @@ mount_as = "/opt/Megatron-Bridge" [cmd_args] gpu_type = "b200" container_image = "nvcr.io#nvidia/nemo:25.11.01" -model_name = "qwen3" -model_size = "30b_a3b" -gpus_per_node = 8 +model_family_name = "qwen3" +model_recipe_name = "30b_a3b" +gpus_per_node = 4 num_gpus = 8 domain = "llm" task = "pretrain" diff --git a/conf/experimental/megatron_bridge/test/gb200/megatron_bridge_qwen_30b.toml b/conf/experimental/megatron_bridge/test/gb200/megatron_bridge_qwen_30b.toml index 8ed62ead3..8802bb4b7 100644 --- a/conf/experimental/megatron_bridge/test/gb200/megatron_bridge_qwen_30b.toml +++ b/conf/experimental/megatron_bridge/test/gb200/megatron_bridge_qwen_30b.toml @@ -28,8 +28,8 @@ mount_as = "/opt/Megatron-Bridge" [cmd_args] gpu_type = "gb200" container_image = "nvcr.io#nvidia/nemo:25.11.01" -model_name = "qwen3" -model_size = "30b_a3b" +model_family_name = "qwen3" +model_recipe_name = "30b_a3b" gpus_per_node = 4 num_gpus = 8 domain = "llm" diff --git a/conf/experimental/megatron_bridge/test/gb300/megatron_bridge_qwen_30b.toml b/conf/experimental/megatron_bridge/test/gb300/megatron_bridge_qwen_30b.toml index 85f30ca16..9fc2db746 100644 --- a/conf/experimental/megatron_bridge/test/gb300/megatron_bridge_qwen_30b.toml +++ b/conf/experimental/megatron_bridge/test/gb300/megatron_bridge_qwen_30b.toml @@ -28,8 +28,8 @@ mount_as = "/opt/Megatron-Bridge" [cmd_args] gpu_type = "gb300" container_image = "nvcr.io#nvidia/nemo:25.11.01" -model_name = "qwen3" -model_size = "30b_a3b" +model_family_name = "qwen3" +model_recipe_name = "30b_a3b" gpus_per_node = 4 num_gpus = 8 domain = "llm" diff --git a/conf/experimental/megatron_bridge/test/h100/megatron_bridge_qwen_30b.toml b/conf/experimental/megatron_bridge/test/h100/megatron_bridge_qwen_30b.toml index f8a397973..4a556fc84 100644 --- a/conf/experimental/megatron_bridge/test/h100/megatron_bridge_qwen_30b.toml +++ b/conf/experimental/megatron_bridge/test/h100/megatron_bridge_qwen_30b.toml @@ -28,8 +28,8 @@ mount_as = "/opt/Megatron-Bridge" [cmd_args] gpu_type = "h100" container_image = "nvcr.io#nvidia/nemo:25.11.01" -model_name = "qwen3" -model_size = "30b_a3b" +model_family_name = "qwen3" +model_recipe_name = "30b_a3b" gpus_per_node = 8 num_gpus = 16 domain = "llm" diff --git a/doc/workloads/megatron_bridge.rst b/doc/workloads/megatron_bridge.rst index 8cbb18a78..7be785858 100644 --- a/doc/workloads/megatron_bridge.rst +++ b/doc/workloads/megatron_bridge.rst @@ -18,9 +18,9 @@ Test TOML example: [cmd_args] # Container can be an NGC/enroot URL (nvcr.io#...) or a local .sqsh path. container_image = "nvcr.io#nvidia/nemo:25.11.01" - - model_name = "qwen3" - model_size = "30b_a3b" + + model_family_name = "qwen3" + model_recipe_name = "30b_a3b" task = "pretrain" domain = "llm" compute_dtype = "fp8_mx" @@ -55,8 +55,8 @@ Test-in-Scenario example: [Tests.cmd_args] container_image = "nvcr.io#nvidia/nemo:25.11.01" - model_name = "qwen3" - model_size = "30b_a3b" + model_family_name = "qwen3" + model_recipe_name = "30b_a3b" task = "pretrain" domain = "llm" compute_dtype = "fp8_mx" diff --git a/src/cloudai/workloads/megatron_bridge/megatron_bridge.py b/src/cloudai/workloads/megatron_bridge/megatron_bridge.py index ef07b8b29..c472d9e93 100644 --- a/src/cloudai/workloads/megatron_bridge/megatron_bridge.py +++ b/src/cloudai/workloads/megatron_bridge/megatron_bridge.py @@ -40,8 +40,9 @@ class MegatronBridgeCmdArgs(CmdArgs): detach: Optional[bool] = Field(default=None) # Model/task - model_name: str = Field(min_length=1) - model_size: str = Field(min_length=1) + model_family_name: str = Field(default="") + model_recipe_name: str = Field(default="") + use_recipes: Optional[bool] = Field(default=None) domain: str = Field(default="llm") task: str = Field(default="pretrain") compute_dtype: str = Field(default="bf16") @@ -49,8 +50,13 @@ class MegatronBridgeCmdArgs(CmdArgs): hf_token: Optional[str] = Field(default=None) nemo_home: Optional[str] = Field(default=None) wandb_key: Optional[str] = Field(default=None) - wandb_prj_name: Optional[str] = Field(default=None) - wandb_exp_name: Optional[str] = Field(default=None) + wandb_project_name: Optional[str] = Field(default=None) + wandb_entity_name: Optional[str] = Field(default=None) + wandb_experiment_name: Optional[str] = Field(default=None) + wandb_save_dir: Optional[str] = Field(default=None) + + # Retries + max_retries: Optional[int] = Field(default=None) # Feature flags (allow sweeps) use_tokendrop: Optional[Union[bool, List[bool]]] = Field(default=None) @@ -69,6 +75,43 @@ class MegatronBridgeCmdArgs(CmdArgs): # Batch sizes mb: Optional[Union[int, List[int]]] = Field(default=None) gb: Optional[Union[int, List[int]]] = Field(default=None) + seq_length: Optional[Union[int, List[int]]] = Field(default=None) + + # Optimizer + lr: Optional[Union[float, List[float]]] = Field(default=None) + min_lr: Optional[Union[float, List[float]]] = Field(default=None) + warmup_iters: Optional[Union[int, List[int]]] = Field(default=None) + + # Checkpointing + pretrained_checkpoint: Optional[str] = Field(default=None) + save_dir: Optional[str] = Field(default=None) + load_dir: Optional[str] = Field(default=None) + save_interval: Optional[int] = Field(default=None) + most_recent_k: Optional[int] = Field(default=None) + save_config_filepath: Optional[str] = Field(default=None) + + # Data / Tokenizer + data: Optional[str] = Field(default=None) + dataset_paths: Optional[Union[str, List[str]]] = Field(default=None) + dataset_root: Optional[str] = Field(default=None) + index_mapping_dir: Optional[str] = Field(default=None) + dataset_name: Optional[str] = Field(default=None) + packed_sequence: Optional[bool] = Field(default=None) + head_only: Optional[bool] = Field(default=None) + tokenizer_type: Optional[str] = Field(default=None) + tokenizer_model: Optional[str] = Field(default=None) + vocab_size: Optional[int] = Field(default=None) + + # Profiling (performance group in argument_parser.py) + pytorch_profiler: Optional[bool] = Field(default=None) + profiling_start_step: Optional[int] = Field(default=None) + profiling_stop_step: Optional[int] = Field(default=None) + record_memory_history: Optional[bool] = Field(default=None) + profiling_gpu_metrics: Optional[bool] = Field(default=None) + profiling_ranks: Optional[Union[int, List[int]]] = Field(default=None) + + # Performance + nccl_ub: Optional[Union[bool, List[bool]]] = Field(default=None) # Perf/tuning moe_a2a_overlap: Optional[Union[bool, List[bool]]] = Field(default=None) diff --git a/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py b/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py index c1f4ff287..5a896be37 100644 --- a/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py +++ b/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py @@ -167,8 +167,8 @@ def _build_launcher_parts( # noqa: C901 ) -> list[str]: fields_set = args.model_fields_set force_fields = { - "model_name", - "model_size", + "model_family_name", + "model_recipe_name", "num_gpus", "gpus_per_node", "hf_token", @@ -211,6 +211,15 @@ def add(flag: str, value: Any) -> None: return if isinstance(value, bool): parts.extend([flag, "true" if value else "false"]) + elif isinstance(value, (list, tuple)): + if not value: + return + if flag == "--dataset_paths": + parts.extend([flag, *[str(x) for x in value]]) + elif flag == "--profiling_ranks": + parts.extend([flag, ",".join(str(x) for x in value)]) + else: + parts.extend([flag, str(value[0]) if len(value) == 1 else ",".join(str(x) for x in value)]) else: sv = str(value) if sv != "": @@ -235,8 +244,11 @@ def add_field(field: str, flag: str, value: Any) -> None: add_field("hf_token", "-hf", args.hf_token) add_field("nemo_home", "-nh", args.nemo_home) add_field("wandb_key", "-wdk", args.wandb_key) - add_field("wandb_prj_name", "-wdp", args.wandb_prj_name) - add_field("wandb_exp_name", "-wdj", args.wandb_exp_name) + add_field("wandb_project_name", "-wdp", args.wandb_project_name) + add_field("wandb_entity_name", "-wde", args.wandb_entity_name) + add_field("wandb_experiment_name", "-wdj", args.wandb_experiment_name) + add_field("wandb_save_dir", "-wds", args.wandb_save_dir) + add_field("max_retries", "--max_retries", args.max_retries) if args.dryrun and "dryrun" in fields_set: parts.append("-d") add_field("num_gpus", "-ng", args.num_gpus) @@ -244,15 +256,17 @@ def add_field(field: str, flag: str, value: Any) -> None: if mounts: add("-cm", ",".join(mounts)) - # Model flags (Megatron-Bridge r0.2.0 API) + # Model flags (Megatron-Bridge main-branch API) + if args.use_recipes and "use_recipes" in fields_set: + parts.append("--use_recipes") if "enable_vboost" in fields_set: add_field("enable_vboost", "-vb", bool(args.enable_vboost)) - if not args.model_name: - raise RuntimeError("Missing required cmd_args.model_name (maps to -m/--model_name).") - if not args.model_size: - raise RuntimeError("Missing required cmd_args.model_size (maps to -s/--model_size).") - add_field("model_name", "-m", args.model_name) - add_field("model_size", "-s", args.model_size) + if not args.model_family_name: + raise RuntimeError("Missing required cmd_args.model_family_name (maps to -m/--model_family_name).") + if not args.model_recipe_name: + raise RuntimeError("Missing required cmd_args.model_recipe_name (maps to -mr/--model_recipe_name).") + add_field("model_family_name", "-m", args.model_family_name) + add_field("model_recipe_name", "-mr", args.model_recipe_name) if args.enable_nsys and "enable_nsys" in fields_set: parts.append("-en") add_field("domain", "--domain", args.domain) @@ -260,6 +274,8 @@ def add_field(field: str, flag: str, value: Any) -> None: add_field("use_tokendrop", "--use_tokendrop", bool(args.use_tokendrop)) if "use_megatron_fsdp" in fields_set and args.use_megatron_fsdp is not None: add_field("use_megatron_fsdp", "--use_megatron_fsdp", bool(args.use_megatron_fsdp)) + if "nccl_ub" in fields_set and args.nccl_ub is not None: + add_field("nccl_ub", "--nccl_ub", bool(args.nccl_ub)) add_field("cuda_graph_impl", "--cuda_graph_impl", args.cuda_graph_impl) if args.cuda_graph_scope and "cuda_graph_scope" in fields_set: add_field( @@ -277,6 +293,7 @@ def add_field(field: str, flag: str, value: Any) -> None: # Batch add_field("mb", "-mb", args.mb) add_field("gb", "-gb", args.gb) + add_field("seq_length", "-sl", args.seq_length) # Misc if "moe_a2a_overlap" in fields_set: @@ -286,11 +303,44 @@ def add_field(field: str, flag: str, value: Any) -> None: add_field("activation_offload_layers", "-ol", args.activation_offload_layers) if args.recompute_modules and "recompute_modules" in fields_set: parts.extend(["--recompute_modules", self._normalize_recompute_modules(args.recompute_modules)]) - # r0.2.0 supports `--detach` / `--no-detach` flags (no boolean value) - if args.detach is True and "detach" in fields_set: - parts.append("--detach") - elif args.detach is False and "detach" in fields_set: - parts.append("--no-detach") + if "detach" in fields_set and args.detach is not None: + parts.extend(["--detach", "true" if args.detach else "false"]) + + # Optimizer + add_field("lr", "--lr", args.lr) + add_field("min_lr", "--min_lr", args.min_lr) + add_field("warmup_iters", "--warmup_iters", args.warmup_iters) + + # Checkpointing + add_field("pretrained_checkpoint", "--pretrained_checkpoint", args.pretrained_checkpoint) + add_field("save_dir", "--save_dir", args.save_dir) + add_field("load_dir", "--load_dir", args.load_dir) + add_field("save_interval", "--save_interval", args.save_interval) + add_field("most_recent_k", "--most_recent_k", args.most_recent_k) + add_field("save_config_filepath", "--save_config_filepath", args.save_config_filepath) + + # Data / Tokenizer + add_field("data", "--data", args.data) + add_field("dataset_paths", "--dataset_paths", args.dataset_paths) + add_field("dataset_root", "--dataset_root", args.dataset_root) + add_field("index_mapping_dir", "--index_mapping_dir", args.index_mapping_dir) + add_field("dataset_name", "--dataset_name", args.dataset_name) + if args.packed_sequence and "packed_sequence" in fields_set: + parts.append("--packed_sequence") + if args.head_only and "head_only" in fields_set: + parts.append("--head_only") + add_field("tokenizer_type", "--tokenizer_type", args.tokenizer_type) + add_field("tokenizer_model", "--tokenizer_model", args.tokenizer_model) + add_field("vocab_size", "--vocab_size", args.vocab_size) + + # Profiling (performance group) + add_field("pytorch_profiler", "-pyp", args.pytorch_profiler) + add_field("profiling_start_step", "--profiling_start_step", args.profiling_start_step) + add_field("profiling_stop_step", "--profiling_stop_step", args.profiling_stop_step) + add_field("record_memory_history", "-mh", args.record_memory_history) + if args.profiling_gpu_metrics and "profiling_gpu_metrics" in fields_set: + parts.append("--profiling_gpu_metrics") + add_field("profiling_ranks", "--profiling_ranks", args.profiling_ranks) # Extra user args (dict -> string) if tdef.extra_cmd_args: diff --git a/tests/slurm_command_gen_strategy/test_megatron_bridge_slurm_command_gen_strategy.py b/tests/slurm_command_gen_strategy/test_megatron_bridge_slurm_command_gen_strategy.py index 3062116c8..48ca34681 100644 --- a/tests/slurm_command_gen_strategy/test_megatron_bridge_slurm_command_gen_strategy.py +++ b/tests/slurm_command_gen_strategy/test_megatron_bridge_slurm_command_gen_strategy.py @@ -38,8 +38,8 @@ def test_run(self, tmp_path: Path) -> TestRun: args = MegatronBridgeCmdArgs( container_image=str(sqsh), hf_token="dummy_token", - model_name="qwen3", - model_size="30b_a3b", + model_family_name="qwen3", + model_recipe_name="30b_a3b", cuda_graph_scope="[moe_router,moe_preprocess]", compute_dtype="fp8_mx", num_gpus=8, @@ -126,8 +126,8 @@ def test_defaults_not_emitted_when_not_set_in_toml(self, slurm_system: SlurmSyst args = MegatronBridgeCmdArgs( container_image=str(sqsh), hf_token="dummy_token", - model_name="qwen3", - model_size="30b_a3b", + model_family_name="qwen3", + model_recipe_name="30b_a3b", num_gpus=8, gpus_per_node=4, ) @@ -186,11 +186,11 @@ def test_cuda_graph_scope_normalization(self, cmd_gen: MegatronBridgeSlurmComman assert "--cuda_graph_scope moe_router,moe_preprocess" in wrapper_content @pytest.mark.parametrize( - "detach, expected, not_expected", + "detach, expected", [ - (True, "--detach", "--no-detach"), - (False, "--no-detach", "--detach"), - (None, None, "--detach"), + (True, "--detach true"), + (False, "--detach false"), + (None, None), ], ) def test_detach_flags( @@ -199,7 +199,6 @@ def test_detach_flags( test_run: TestRun, detach: bool | None, expected: str | None, - not_expected: str, ) -> None: tdef = cast(MegatronBridgeTestDefinition, test_run.test) @@ -218,11 +217,9 @@ def test_detach_flags( wrapper_content = wrapper.read_text() if detach is None: assert "--detach" not in wrapper_content - assert "--no-detach" not in wrapper_content else: assert expected is not None assert expected in wrapper_content - assert not_expected not in wrapper_content def test_generated_command_file_written( self, cmd_gen: MegatronBridgeSlurmCommandGenStrategy, test_run: TestRun @@ -235,3 +232,72 @@ def test_generated_command_file_written( assert cmd in content assert content.startswith("bash ") assert "cloudai_megatron_bridge_submit_and_parse_jobid.sh" in content + + def test_use_recipes_emitted_only_when_true(self, slurm_system: SlurmSystem, tmp_path: Path) -> None: + sqsh = tmp_path / "img.sqsh" + sqsh.write_text("x") + + base = dict( + container_image=str(sqsh), + hf_token="dummy_token", + model_family_name="qwen3", + model_recipe_name="30b_a3b", + num_gpus=8, + gpus_per_node=4, + ) + + tdef_true = MegatronBridgeTestDefinition( + name="mb", + description="desc", + test_template_name="MegatronBridge", + cmd_args=MegatronBridgeCmdArgs(**base, use_recipes=True), + extra_container_mounts=[], + git_repos=[ + { + "url": "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", + "commit": "r0.2.0", + "mount_as": "/opt/Megatron-Bridge", + } + ], # type: ignore[arg-type] + ) + + (tmp_path / "run_repo").mkdir() + (tmp_path / "run_venv").mkdir() + (tmp_path / "mbridge_repo").mkdir() + tdef_true.python_executable.git_repo.installed_path = tmp_path / "run_repo" + tdef_true.python_executable.venv_path = tmp_path / "run_venv" + tdef_true.megatron_bridge_repo.installed_path = tmp_path / "mbridge_repo" + tdef_true.docker_image.installed_path = tmp_path / "cached.sqsh" + + tr_true = TestRun(test=tdef_true, name="tr", num_nodes=1, nodes=[], output_path=tmp_path / "out_true") + slurm_system.account = "acct" + slurm_system.default_partition = "gb300" + cmd_gen_true = MegatronBridgeSlurmCommandGenStrategy(slurm_system, tr_true) + cmd_gen_true.gen_exec_command() + wrapper_true = tr_true.output_path / "cloudai_megatron_bridge_submit_and_parse_jobid.sh" + assert "--use_recipes" in wrapper_true.read_text() + + tdef_none = MegatronBridgeTestDefinition( + name="mb", + description="desc", + test_template_name="MegatronBridge", + cmd_args=MegatronBridgeCmdArgs(**base), + extra_container_mounts=[], + git_repos=[ + { + "url": "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", + "commit": "r0.2.0", + "mount_as": "/opt/Megatron-Bridge", + } + ], # type: ignore[arg-type] + ) + tdef_none.python_executable.git_repo.installed_path = tmp_path / "run_repo" + tdef_none.python_executable.venv_path = tmp_path / "run_venv" + tdef_none.megatron_bridge_repo.installed_path = tmp_path / "mbridge_repo" + tdef_none.docker_image.installed_path = tmp_path / "cached.sqsh" + + tr_none = TestRun(test=tdef_none, name="tr", num_nodes=1, nodes=[], output_path=tmp_path / "out_none") + cmd_gen_none = MegatronBridgeSlurmCommandGenStrategy(slurm_system, tr_none) + cmd_gen_none.gen_exec_command() + wrapper_none = tr_none.output_path / "cloudai_megatron_bridge_submit_and_parse_jobid.sh" + assert "--use_recipes" not in wrapper_none.read_text() From 357c00a634b63b117209686d564b9ff07926cb0e Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Wed, 7 Jan 2026 15:34:25 -0800 Subject: [PATCH 12/22] fix pyright issue --- ...atron_bridge_slurm_command_gen_strategy.py | 30 +++++++++++-------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/tests/slurm_command_gen_strategy/test_megatron_bridge_slurm_command_gen_strategy.py b/tests/slurm_command_gen_strategy/test_megatron_bridge_slurm_command_gen_strategy.py index 48ca34681..8ce5ec502 100644 --- a/tests/slurm_command_gen_strategy/test_megatron_bridge_slurm_command_gen_strategy.py +++ b/tests/slurm_command_gen_strategy/test_megatron_bridge_slurm_command_gen_strategy.py @@ -102,7 +102,7 @@ def test_model_fields_whitespace_only_rejected(self, field_name: str) -> None: MegatronBridgeCmdArgs.model_validate(data) def test_git_repos_can_pin_megatron_bridge_commit(self) -> None: - args = MegatronBridgeCmdArgs(hf_token="dummy_token", model_name="qwen3", model_size="30b_a3b") + args = MegatronBridgeCmdArgs(hf_token="dummy_token", model_family_name="qwen3", model_recipe_name="30b_a3b") tdef = MegatronBridgeTestDefinition( name="mb", description="desc", @@ -237,20 +237,19 @@ def test_use_recipes_emitted_only_when_true(self, slurm_system: SlurmSystem, tmp sqsh = tmp_path / "img.sqsh" sqsh.write_text("x") - base = dict( - container_image=str(sqsh), - hf_token="dummy_token", - model_family_name="qwen3", - model_recipe_name="30b_a3b", - num_gpus=8, - gpus_per_node=4, - ) - tdef_true = MegatronBridgeTestDefinition( name="mb", description="desc", test_template_name="MegatronBridge", - cmd_args=MegatronBridgeCmdArgs(**base, use_recipes=True), + cmd_args=MegatronBridgeCmdArgs( + container_image=str(sqsh), + hf_token="dummy_token", + model_family_name="qwen3", + model_recipe_name="30b_a3b", + num_gpus=8, + gpus_per_node=4, + use_recipes=True, + ), extra_container_mounts=[], git_repos=[ { @@ -281,7 +280,14 @@ def test_use_recipes_emitted_only_when_true(self, slurm_system: SlurmSystem, tmp name="mb", description="desc", test_template_name="MegatronBridge", - cmd_args=MegatronBridgeCmdArgs(**base), + cmd_args=MegatronBridgeCmdArgs( + container_image=str(sqsh), + hf_token="dummy_token", + model_family_name="qwen3", + model_recipe_name="30b_a3b", + num_gpus=8, + gpus_per_node=4, + ), extra_container_mounts=[], git_repos=[ { From 9d3d45467b1a9517eb1b923a166275d633ac5d73 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Mon, 12 Jan 2026 15:10:57 -0800 Subject: [PATCH 13/22] remove domain flag --- src/cloudai/workloads/megatron_bridge/megatron_bridge.py | 1 - .../workloads/megatron_bridge/slurm_command_gen_strategy.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/cloudai/workloads/megatron_bridge/megatron_bridge.py b/src/cloudai/workloads/megatron_bridge/megatron_bridge.py index c472d9e93..34dace7d1 100644 --- a/src/cloudai/workloads/megatron_bridge/megatron_bridge.py +++ b/src/cloudai/workloads/megatron_bridge/megatron_bridge.py @@ -43,7 +43,6 @@ class MegatronBridgeCmdArgs(CmdArgs): model_family_name: str = Field(default="") model_recipe_name: str = Field(default="") use_recipes: Optional[bool] = Field(default=None) - domain: str = Field(default="llm") task: str = Field(default="pretrain") compute_dtype: str = Field(default="bf16") fp8_recipe: Optional[str] = Field(default=None) diff --git a/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py b/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py index 5a896be37..3eaba635c 100644 --- a/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py +++ b/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py @@ -269,7 +269,6 @@ def add_field(field: str, flag: str, value: Any) -> None: add_field("model_recipe_name", "-mr", args.model_recipe_name) if args.enable_nsys and "enable_nsys" in fields_set: parts.append("-en") - add_field("domain", "--domain", args.domain) if "use_tokendrop" in fields_set and args.use_tokendrop is not None: add_field("use_tokendrop", "--use_tokendrop", bool(args.use_tokendrop)) if "use_megatron_fsdp" in fields_set and args.use_megatron_fsdp is not None: From 18fa65354b3db5e2a189b12bd715cbd9c6f534c5 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Tue, 13 Jan 2026 15:43:47 -0800 Subject: [PATCH 14/22] fix the stderr/stdout --- .../workloads/megatron_bridge/slurm_command_gen_strategy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py b/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py index 3eaba635c..09638685a 100644 --- a/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py +++ b/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py @@ -130,7 +130,7 @@ def _wrap_launcher_for_job_id_and_quiet_output(self, launcher_cmd: str) -> str: "", ': >"$LOG"', "LAUNCH_RC=0", - f'{launcher_cmd} >>"$LOG" 2>&1 || LAUNCH_RC=$?', + f"{launcher_cmd} >>\"$LOG\" 2>&1 || LAUNCH_RC=$?", "", # Parse job id from Megatron-Bridge output (multiple possible formats) "", From c8149748db3b4f3632ccfbcd393fe1ea13e3ead2 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Fri, 16 Jan 2026 18:06:43 -0800 Subject: [PATCH 15/22] make cuda_graph_impl dse'ble flag --- src/cloudai/workloads/megatron_bridge/megatron_bridge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cloudai/workloads/megatron_bridge/megatron_bridge.py b/src/cloudai/workloads/megatron_bridge/megatron_bridge.py index 34dace7d1..a149f3802 100644 --- a/src/cloudai/workloads/megatron_bridge/megatron_bridge.py +++ b/src/cloudai/workloads/megatron_bridge/megatron_bridge.py @@ -60,7 +60,7 @@ class MegatronBridgeCmdArgs(CmdArgs): # Feature flags (allow sweeps) use_tokendrop: Optional[Union[bool, List[bool]]] = Field(default=None) use_megatron_fsdp: Optional[Union[bool, List[bool]]] = Field(default=None) - cuda_graph_impl: Optional[str] = Field(default=None) + cuda_graph_impl: Optional[Union[str, List[str]]] = Field(default=None) cuda_graph_scope: Optional[Union[str, List[str]]] = Field(default=None) # Parallelism From c256e98ec658bbde8b7e3fccd2f7bfbea9324d4a Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Fri, 16 Jan 2026 18:10:35 -0800 Subject: [PATCH 16/22] fix regex --- .../workloads/megatron_bridge/slurm_command_gen_strategy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py b/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py index 09638685a..fad02839a 100644 --- a/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py +++ b/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py @@ -133,10 +133,10 @@ def _wrap_launcher_for_job_id_and_quiet_output(self, launcher_cmd: str) -> str: f"{launcher_cmd} >>\"$LOG\" 2>&1 || LAUNCH_RC=$?", "", # Parse job id from Megatron-Bridge output (multiple possible formats) + # Patterns: "Job id: 694112", "- Job id: 694112", "Job ID: 694112" "", 'JOB_ID=""', - 'JOB_ID=$(grep -Eio "Job[[:space:]]+id[: ]+[0-9]+" "$LOG" | ' - 'tail -n1 | grep -Eo "[0-9]+" | tail -n1 || true)', + 'JOB_ID=$(grep -Eio "(Job id[: ]+[0-9]+|-[ ]*Job id[: ]+[0-9]+)" "$LOG" | tail -n1 | grep -Eo "[0-9]+" | tail -n1 || true)', "", # Emit a canonical line for CloudAI to parse "", From b8cdf37f812d1ab516c3636253f44fd3e65c7e27 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Wed, 21 Jan 2026 14:48:03 -0800 Subject: [PATCH 17/22] add constraint for moe_overlap --- .../workloads/megatron_bridge/megatron_bridge.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/cloudai/workloads/megatron_bridge/megatron_bridge.py b/src/cloudai/workloads/megatron_bridge/megatron_bridge.py index a149f3802..9993b6b96 100644 --- a/src/cloudai/workloads/megatron_bridge/megatron_bridge.py +++ b/src/cloudai/workloads/megatron_bridge/megatron_bridge.py @@ -365,14 +365,16 @@ def _normalize_str_list(val: Optional[Union[str, List[str]]]) -> list[str]: else: constraint10 = True - # Constraint 11: CUDA graphs require a2a overlap disabled + # Constraint 11: When cuda_graph_impl is set (not none), a2a overlap must be disabled + # moe_a2a_overlap can only be true when cuda_graph_impl is 'none' or unset a2a_overlap = _as_bool(self.cmd_args.moe_a2a_overlap) - constraint11 = not (cuda_graphs and a2a_overlap) + cuda_impl_enabled = cgi not in {"", "none", "null"} + constraint11 = not (cuda_impl_enabled and a2a_overlap) if not constraint11: logging.error( - "Constraint 11 failed: cuda_graphs=true requires moe_a2a_overlap=false. " - "cuda_graphs=%s moe_a2a_overlap=%s", - cuda_graphs, + "Constraint 11 failed: moe_a2a_overlap must be false when cuda_graph_impl is not 'none'. " + "cuda_graph_impl=%s moe_a2a_overlap=%s", + cgi, a2a_overlap, ) From 89e11590fd833495fb293b8e370a2d16e07a6cd8 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Mon, 26 Jan 2026 17:43:52 -0800 Subject: [PATCH 18/22] add vp/pp constrint check for deepseek --- .../megatron_bridge/megatron_bridge.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/cloudai/workloads/megatron_bridge/megatron_bridge.py b/src/cloudai/workloads/megatron_bridge/megatron_bridge.py index 9993b6b96..036839f4b 100644 --- a/src/cloudai/workloads/megatron_bridge/megatron_bridge.py +++ b/src/cloudai/workloads/megatron_bridge/megatron_bridge.py @@ -476,6 +476,27 @@ def _normalize_str_list(val: Optional[Union[str, List[str]]]) -> list[str]: else: constraint17 = True + # Constraint 18: Valid (PP, VP) combinations for DeepSeek v3 pipeline layout + # Only specific (pp, vp) pairs are supported by DeepSeek v3's pipeline layout mapping + model_recipe = (self.cmd_args.model_recipe_name or "").lower() + is_deepseek_v3 = "deepseek_v3" in model_recipe or "deepseekv3" in model_recipe + + if is_deepseek_v3: + valid_pp_vp_combinations = {(1, 1), (4, 1), (8, 1), (4, 2), (16, 1), (8, 2), (4, 4)} + current_vp = vp if vp is not None else 1 + pp_vp_pair = (pp, current_vp) + constraint18 = pp_vp_pair in valid_pp_vp_combinations + if not constraint18: + logging.error( + "Constraint 18 failed: Invalid (PP, VP) combination for DeepSeek v3. pp=%s vp=%s. " + "Valid combinations: %s", + pp, + current_vp, + sorted(valid_pp_vp_combinations), + ) + else: + constraint18 = True # Skip this constraint for non-DeepSeek v3 models + return bool( constraint1 and constraint2 @@ -494,4 +515,5 @@ def _normalize_str_list(val: Optional[Union[str, List[str]]]) -> list[str]: and constraint15 and constraint16 and constraint17 + and constraint18 ) From 37053e5c09f4e1803b12b91184b6896faf0dcd9b Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 12 Feb 2026 13:37:10 -0800 Subject: [PATCH 19/22] fix the new flags --- src/cloudai/workloads/megatron_bridge/megatron_bridge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cloudai/workloads/megatron_bridge/megatron_bridge.py b/src/cloudai/workloads/megatron_bridge/megatron_bridge.py index 036839f4b..90e89ac9c 100644 --- a/src/cloudai/workloads/megatron_bridge/megatron_bridge.py +++ b/src/cloudai/workloads/megatron_bridge/megatron_bridge.py @@ -130,7 +130,7 @@ def validate_hf_token(cls, v: Optional[str]) -> Optional[str]: raise ValueError("cmd_args.hf_token is required. Please set it to your literal HF token string.") return token - @field_validator("model_name", "model_size", mode="after") + @field_validator("model_family_name", "model_recipe_name", mode="after") @classmethod def validate_model_fields(cls, v: str, info: ValidationInfo) -> str: s = v.strip() From a7172cecca5a9a36570f34cdaa8e68c484d72e1d Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 12 Feb 2026 13:43:39 -0800 Subject: [PATCH 20/22] fix --- src/cloudai/workloads/megatron_bridge/megatron_bridge.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cloudai/workloads/megatron_bridge/megatron_bridge.py b/src/cloudai/workloads/megatron_bridge/megatron_bridge.py index 90e89ac9c..df8af66d8 100644 --- a/src/cloudai/workloads/megatron_bridge/megatron_bridge.py +++ b/src/cloudai/workloads/megatron_bridge/megatron_bridge.py @@ -122,7 +122,7 @@ class MegatronBridgeCmdArgs(CmdArgs): # Optional distributed optimizer instances (for constraints/divisor) num_distributed_optimizer_instances: Optional[int] = Field(default=None) - @field_validator("hf_token", mode="after") + @field_validator("hf_token", mode="after", check_fields=False) @classmethod def validate_hf_token(cls, v: Optional[str]) -> Optional[str]: token = (v or "").strip() @@ -130,7 +130,7 @@ def validate_hf_token(cls, v: Optional[str]) -> Optional[str]: raise ValueError("cmd_args.hf_token is required. Please set it to your literal HF token string.") return token - @field_validator("model_family_name", "model_recipe_name", mode="after") + @field_validator("model_family_name", "model_recipe_name", mode="after", check_fields=False) @classmethod def validate_model_fields(cls, v: str, info: ValidationInfo) -> str: s = v.strip() From a5a13990112fd0f2e98b1399cf1faccf8e9454f3 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 12 Feb 2026 13:56:05 -0800 Subject: [PATCH 21/22] more fixes --- .../megatron_bridge/megatron_bridge.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/cloudai/workloads/megatron_bridge/megatron_bridge.py b/src/cloudai/workloads/megatron_bridge/megatron_bridge.py index df8af66d8..8c799bf35 100644 --- a/src/cloudai/workloads/megatron_bridge/megatron_bridge.py +++ b/src/cloudai/workloads/megatron_bridge/megatron_bridge.py @@ -17,7 +17,7 @@ import logging from typing import List, Optional, Union, cast -from pydantic import Field, ValidationInfo, field_validator +from pydantic import Field, field_validator, model_validator from cloudai.core import DockerImage, GitRepo, Installable, PythonExecutable from cloudai.models.workload import CmdArgs, TestDefinition @@ -130,13 +130,16 @@ def validate_hf_token(cls, v: Optional[str]) -> Optional[str]: raise ValueError("cmd_args.hf_token is required. Please set it to your literal HF token string.") return token - @field_validator("model_family_name", "model_recipe_name", mode="after", check_fields=False) - @classmethod - def validate_model_fields(cls, v: str, info: ValidationInfo) -> str: - s = v.strip() - if not s: - raise ValueError(f"cmd_args.{info.field_name} cannot be empty.") - return s + @model_validator(mode="after") + def validate_model_fields_non_empty(self) -> "MegatronBridgeCmdArgs": + """Ensure model_family_name and model_recipe_name are non-empty and stripped.""" + for name in ("model_family_name", "model_recipe_name"): + val = getattr(self, name, "") or "" + s = str(val).strip() + if not s: + raise ValueError(f"cmd_args.{name} cannot be empty.") + setattr(self, name, s) + return self class MegatronBridgeTestDefinition(TestDefinition): From 69bc44de2ae8c706443697d3cc9ab94cdc1ee612 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Fri, 13 Feb 2026 12:34:12 -0800 Subject: [PATCH 22/22] remove -cm and update all flags in m-bridge --- .../megatron_bridge/megatron_bridge.py | 36 ++++++++++++++++++- .../slurm_command_gen_strategy.py | 31 ++++++++++++++-- 2 files changed, 64 insertions(+), 3 deletions(-) diff --git a/src/cloudai/workloads/megatron_bridge/megatron_bridge.py b/src/cloudai/workloads/megatron_bridge/megatron_bridge.py index 8c799bf35..dc2c5e1c8 100644 --- a/src/cloudai/workloads/megatron_bridge/megatron_bridge.py +++ b/src/cloudai/workloads/megatron_bridge/megatron_bridge.py @@ -33,12 +33,28 @@ class MegatronBridgeCmdArgs(CmdArgs): container_image: str = Field(default="") num_gpus: int = Field(default=8) gpus_per_node: int = Field(default=8) - custom_mounts: Optional[str] = Field(default=None) + custom_mounts: Optional[Union[str, List[str]]] = Field( + default=None, + description="Comma-separated or list of host_path:container_path mounts; merged with test-level extra_container_mounts for -cm.", + ) enable_vboost: Optional[bool] = Field(default=False) dryrun: Optional[bool] = Field(default=False) enable_nsys: Optional[bool] = Field(default=False) detach: Optional[bool] = Field(default=None) + # Domain / model overrides (argument_parser main + model) + domain: Optional[str] = Field( + default=None, + description="Domain: llm, vlm, or qwen3vl (default llm).", + ) + hidden_size: Optional[int] = Field(default=None, description="Override hidden size for experiment.") + num_layers: Optional[int] = Field(default=None, description="Override number of layers.") + pipeline_model_parallel_layout: Optional[str] = Field(default=None) + first_k_dense_replace: Optional[int] = Field( + default=None, + description="Number of MoE layers to convert to dense.", + ) + # Model/task model_family_name: str = Field(default="") model_recipe_name: str = Field(default="") @@ -108,6 +124,14 @@ class MegatronBridgeCmdArgs(CmdArgs): record_memory_history: Optional[bool] = Field(default=None) profiling_gpu_metrics: Optional[bool] = Field(default=None) profiling_ranks: Optional[Union[int, List[int]]] = Field(default=None) + nsys_trace: Optional[Union[str, List[str]]] = Field( + default=None, + description="Comma-separated nsys trace events (e.g. cuda,nvtx).", + ) + nsys_extra_args: Optional[Union[str, List[str]]] = Field( + default=None, + description="Comma-separated extra arguments for nsys.", + ) # Performance nccl_ub: Optional[Union[bool, List[bool]]] = Field(default=None) @@ -122,6 +146,16 @@ class MegatronBridgeCmdArgs(CmdArgs): # Optional distributed optimizer instances (for constraints/divisor) num_distributed_optimizer_instances: Optional[int] = Field(default=None) + # Config variant (argument_parser config_variant group) + config_variant: Optional[str] = Field( + default=None, + description="Config variant (e.g. v1, v2). Launcher default is v2.", + ) + list_config_variants: Optional[bool] = Field( + default=None, + description="If true, list config variants and interactively select (--list_config_variants).", + ) + @field_validator("hf_token", mode="after", check_fields=False) @classmethod def validate_hf_token(cls, v: Optional[str]) -> Optional[str]: diff --git a/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py b/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py index fad02839a..18df61685 100644 --- a/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py +++ b/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py @@ -19,7 +19,7 @@ import logging import stat from pathlib import Path -from typing import Any, cast +from typing import Any, Optional, cast import toml @@ -162,6 +162,16 @@ def _wrap_launcher_for_job_id_and_quiet_output(self, launcher_cmd: str) -> str: return f"bash {wrapper_path}" + def _list_or_comma_str(self, val: Any) -> Optional[str]: + """Normalize list or comma-separated string; return None if empty or val is None.""" + if val is None: + return None + if isinstance(val, str): + s = val.strip() + else: + s = ",".join(str(x) for x in val).strip() + return s if s else None + def _build_launcher_parts( # noqa: C901 self, args: MegatronBridgeCmdArgs, tdef: MegatronBridgeTestDefinition, repo_path: Path, launcher_py: Path ) -> list[str]: @@ -194,8 +204,13 @@ def _installed_container_path() -> str: else: container_path = _installed_container_path() + # Combine cmd_args custom_mounts with test-level extra_container_mounts; only pass -cm when non-empty mounts: list[str] = [] - mounts.append(f"{repo_path.absolute()}:/opt/Megatron-Bridge") + if args.custom_mounts is not None: + if isinstance(args.custom_mounts, str): + mounts.extend(m.strip() for m in args.custom_mounts.split(",") if m.strip()) + else: + mounts.extend(str(m).strip() for m in args.custom_mounts if str(m).strip()) mounts.extend(tdef.extra_container_mounts or []) venv_path = tdef.python_executable.venv_path or (self.system.install_path / tdef.python_executable.venv_name) @@ -257,6 +272,7 @@ def add_field(field: str, flag: str, value: Any) -> None: add("-cm", ",".join(mounts)) # Model flags (Megatron-Bridge main-branch API) + add_field("domain", "--domain", args.domain) if args.use_recipes and "use_recipes" in fields_set: parts.append("--use_recipes") if "enable_vboost" in fields_set: @@ -267,6 +283,10 @@ def add_field(field: str, flag: str, value: Any) -> None: raise RuntimeError("Missing required cmd_args.model_recipe_name (maps to -mr/--model_recipe_name).") add_field("model_family_name", "-m", args.model_family_name) add_field("model_recipe_name", "-mr", args.model_recipe_name) + add_field("hidden_size", "--hidden_size", args.hidden_size) + add_field("num_layers", "--num_layers", args.num_layers) + add_field("pipeline_model_parallel_layout", "--pipeline_model_parallel_layout", args.pipeline_model_parallel_layout) + add_field("first_k_dense_replace", "--first_k_dense_replace", args.first_k_dense_replace) if args.enable_nsys and "enable_nsys" in fields_set: parts.append("-en") if "use_tokendrop" in fields_set and args.use_tokendrop is not None: @@ -340,6 +360,13 @@ def add_field(field: str, flag: str, value: Any) -> None: if args.profiling_gpu_metrics and "profiling_gpu_metrics" in fields_set: parts.append("--profiling_gpu_metrics") add_field("profiling_ranks", "--profiling_ranks", args.profiling_ranks) + add_field("nsys_trace", "--nsys_trace", self._list_or_comma_str(args.nsys_trace)) + add_field("nsys_extra_args", "--nsys_extra_args", self._list_or_comma_str(args.nsys_extra_args)) + + # Config variant + add_field("config_variant", "-cv", args.config_variant) + if args.list_config_variants and "list_config_variants" in fields_set: + parts.append("--list_config_variants") # Extra user args (dict -> string) if tdef.extra_cmd_args: