From 9fc6cb765a6f5dfc92dfd1a2782fe59eb70a037d Mon Sep 17 00:00:00 2001 From: Matthew Grange Date: Thu, 19 Feb 2026 13:00:37 -0800 Subject: [PATCH 1/6] Rename max_parallelism to max_concurrency in GenerationStep and dispatch_utils Summary: Renames the `max_parallelism` parameter to `max_concurrency` across GenerationStep, GenerationNode, and the generation strategy dispatch utilities. Adds backward-compatible deprecated `max_parallelism` parameters with deprecation warnings where the public API is affected (`choose_generation_strategy`). Internal variable names (`sobol_parallelism`, `bo_parallelism`) are renamed to `sobol_concurrency`, `bo_concurrency` for consistency. Differential Revision: D92457714 --- ax/generation_strategy/dispatch_utils.py | 131 +++++++++++------- ax/generation_strategy/generation_node.py | 32 +++-- .../tests/test_dispatch_utils.py | 68 ++++----- 3 files changed, 140 insertions(+), 91 deletions(-) diff --git a/ax/generation_strategy/dispatch_utils.py b/ax/generation_strategy/dispatch_utils.py index f2721d47fc6..52326c66510 100644 --- a/ax/generation_strategy/dispatch_utils.py +++ b/ax/generation_strategy/dispatch_utils.py @@ -7,6 +7,7 @@ # pyre-strict import logging +import warnings from math import ceil from typing import Any @@ -16,6 +17,7 @@ from ax.core.optimization_config import OptimizationConfig from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter from ax.core.search_space import SearchSpace +from ax.exceptions.core import UserInputError from ax.generation_strategy.generation_strategy import ( GenerationStep, GenerationStrategy, @@ -31,7 +33,7 @@ logger: logging.Logger = get_logger(__name__) -DEFAULT_BAYESIAN_PARALLELISM = 3 +DEFAULT_BAYESIAN_CONCURRENCY = 3 # `BO_MIXED` optimizes all range parameters once for each combination of choice # parameters, then takes the optimum of those optima. 
The cost associated with this # method grows with the number of combinations, and so it is only used when the @@ -49,7 +51,7 @@ def _make_sobol_step( num_trials: int = -1, min_trials_observed: int | None = None, enforce_num_trials: bool = True, - max_parallelism: int | None = None, + max_concurrency: int | None = None, seed: int | None = None, should_deduplicate: bool = False, ) -> GenerationStep: @@ -62,7 +64,7 @@ def _make_sobol_step( ceil(num_trials / 2) if min_trials_observed is None else min_trials_observed ), enforce_num_trials=enforce_num_trials, - max_parallelism=max_parallelism, + max_concurrency=max_concurrency, generator_kwargs={"deduplicate": True, "seed": seed}, should_deduplicate=should_deduplicate, use_all_trials_in_exp=True, @@ -73,7 +75,7 @@ def _make_botorch_step( num_trials: int = -1, min_trials_observed: int | None = None, enforce_num_trials: bool = True, - max_parallelism: int | None = None, + max_concurrency: int | None = None, generator: GeneratorRegistryBase = Generators.BOTORCH_MODULAR, generator_kwargs: dict[str, Any] | None = None, winsorization_config: None @@ -130,7 +132,7 @@ def _make_botorch_step( ceil(num_trials / 2) if min_trials_observed is None else min_trials_observed ), enforce_num_trials=enforce_num_trials, - max_parallelism=max_parallelism, + max_concurrency=max_concurrency, generator_kwargs=generator_kwargs, should_deduplicate=should_deduplicate, ) @@ -300,8 +302,8 @@ def choose_generation_strategy_legacy( num_completed_initialization_trials: int = 0, max_initialization_trials: int | None = None, min_sobol_trials_observed: int | None = None, - max_parallelism_cap: int | None = None, - max_parallelism_override: int | None = None, + max_concurrency_cap: int | None = None, + max_concurrency_override: int | None = None, optimization_config: OptimizationConfig | None = None, should_deduplicate: bool = False, use_saasbo: bool = False, @@ -311,6 +313,9 @@ def choose_generation_strategy_legacy( suggested_model_override: 
GeneratorRegistryBase | None = None, use_input_warping: bool = False, simplify_parameter_changes: bool = False, + # Deprecated arguments for backwards compatibility. + max_parallelism_cap: int | None = None, + max_parallelism_override: int | None = None, ) -> GenerationStrategy: """Select an appropriate generation strategy based on the properties of the search space and expected settings of the experiment, such as number of @@ -325,11 +330,11 @@ def choose_generation_strategy_legacy( enforce_sequential_optimization: Whether to enforce that 1) the generation strategy needs to be updated with ``min_trials_observed`` observations for a given generation step before proceeding to the next one and 2) maximum - number of trials running at once (max_parallelism) if enforced for the - BayesOpt step. NOTE: ``max_parallelism_override`` and - ``max_parallelism_cap`` settings will still take their effect on max - parallelism even if ``enforce_sequential_optimization=False``, so if those - settings are specified, max parallelism will be enforced. + number of trials running at once (max_concurrency) is enforced for the + BayesOpt step. NOTE: ``max_concurrency_override`` and + ``max_concurrency_cap`` settings will still take their effect on max + concurrency even if ``enforce_sequential_optimization=False``, so if those + settings are specified, max concurrency will be enforced. random_seed: Fixed random seed for the Sobol generator. torch_device: The device to use for generation steps implemented in PyTorch (e.g. via BoTorch). Some generation steps (in particular EHVI-based ones @@ -360,21 +365,21 @@ def choose_generation_strategy_legacy( min_sobol_trials_observed: Minimum number of Sobol trials that must be observed before proceeding to the next generation step. Defaults to `ceil(num_initialization_trials / 2)`. - max_parallelism_cap: Integer cap on parallelism in this generation strategy.
- If specified, ``max_parallelism`` setting in each generation step will be + max_concurrency_cap: Integer cap on concurrency in this generation strategy. + If specified, ``max_concurrency`` setting in each generation step will be set to the minimum of the default setting for that step and the value of - this cap. ``max_parallelism_cap`` is meant to just be a hard limit on - parallelism (e.g. to avoid overloading machine(s) that evaluate the + this cap. ``max_concurrency_cap`` is meant to just be a hard limit on + concurrency (e.g. to avoid overloading machine(s) that evaluate the experiment trials). Specify only if not specifying - ``max_parallelism_override``. - max_parallelism_override: Integer, with which to override the default max - parallelism setting for all steps in the generation strategy returned from - this function. Each generation step has a ``max_parallelism`` value, which + ``max_concurrency_override``. + max_concurrency_override: Integer, with which to override the default max + concurrency setting for all steps in the generation strategy returned from + this function. Each generation step has a ``max_concurrency`` value, which restricts how many trials can run simultaneously during a given generation - step. By default, the parallelism setting is chosen as appropriate for the - model in a given generation step. If ``max_parallelism_override`` is -1, - no max parallelism will be enforced for any step of the generation - strategy. Be aware that parallelism is limited to improve performance of + step. By default, the concurrency setting is chosen as appropriate for the + model in a given generation step. If ``max_concurrency_override`` is -1, + no max concurrency will be enforced for any step of the generation + strategy. Be aware that concurrency is limited to improve performance of Bayesian optimization, so only disable its limiting if necessary. optimization_config: Used to infer whether to use MOO. 
should_deduplicate: Whether to deduplicate the parameters of proposed arms @@ -407,6 +412,34 @@ def choose_generation_strategy_legacy( simplify parameter changes in arms generated via Bayesian Optimization by pruning irrelevant parameter changes. """ + # Handle deprecated arguments. + if max_parallelism_cap is not None: + warnings.warn( + "`max_parallelism_cap` is deprecated and will be removed in Ax 1.4. " + "Use `max_concurrency_cap` instead.", + DeprecationWarning, + stacklevel=2, + ) + if max_concurrency_cap is not None: + raise UserInputError( + "Cannot specify both `max_parallelism_cap` and `max_concurrency_cap`." + ) + max_concurrency_cap = max_parallelism_cap + + if max_parallelism_override is not None: + warnings.warn( + "`max_parallelism_override` is deprecated and will be removed in Ax 1.4. " + "Use `max_concurrency_override` instead.", + DeprecationWarning, + stacklevel=2, + ) + if max_concurrency_override is not None: + raise UserInputError( + "Cannot specify both `max_parallelism_override` and " + "`max_concurrency_override`." + ) + max_concurrency_override = max_parallelism_override + if experiment is not None and optimization_config is None: optimization_config = experiment.optimization_config @@ -416,36 +449,36 @@ def choose_generation_strategy_legacy( optimization_config=optimization_config, use_saasbo=use_saasbo, ) - # Determine max parallelism for the generation steps. - if max_parallelism_override == -1: - # `max_parallelism_override` of -1 means no max parallelism enforcement in - # the generation strategy, which means `max_parallelism=None` in gen. steps. 
- sobol_parallelism = bo_parallelism = None - elif max_parallelism_override is not None: - sobol_parallelism = bo_parallelism = max_parallelism_override - elif max_parallelism_cap is not None: # Max parallelism override is None by now - sobol_parallelism = max_parallelism_cap - bo_parallelism = min(max_parallelism_cap, DEFAULT_BAYESIAN_PARALLELISM) + # Determine max concurrency for the generation steps. + if max_concurrency_override == -1: + # `max_concurrency_override` of -1 means no max concurrency enforcement in + # the generation strategy, which means `max_concurrency=None` in gen. steps. + sobol_concurrency = bo_concurrency = None + elif max_concurrency_override is not None: + sobol_concurrency = bo_concurrency = max_concurrency_override + elif max_concurrency_cap is not None: # Max concurrency override is None by now + sobol_concurrency = max_concurrency_cap + bo_concurrency = min(max_concurrency_cap, DEFAULT_BAYESIAN_CONCURRENCY) elif not enforce_sequential_optimization: - # If no max parallelism settings specified and not enforcing sequential - # optimization, do not limit parallelism. - sobol_parallelism = bo_parallelism = None - else: # No additional max parallelism settings, use defaults - sobol_parallelism = None # No restriction on Sobol phase - bo_parallelism = DEFAULT_BAYESIAN_PARALLELISM + # If no max concurrency settings specified and not enforcing sequential + # optimization, do not limit concurrency. 
+ sobol_concurrency = bo_concurrency = None + else: # No additional max concurrency settings, use defaults + sobol_concurrency = None # No restriction on Sobol phase + bo_concurrency = DEFAULT_BAYESIAN_CONCURRENCY if not force_random_search and suggested_model is not None: if not enforce_sequential_optimization and ( - max_parallelism_override is not None or max_parallelism_cap is not None + max_concurrency_override is not None or max_concurrency_cap is not None ): logger.info( - "If `enforce_sequential_optimization` is False, max parallelism is " - "not enforced and other max parallelism settings will be ignored." + "If `enforce_sequential_optimization` is False, max concurrency is " + "not enforced and other max concurrency settings will be ignored." ) - if max_parallelism_override is not None and max_parallelism_cap is not None: + if max_concurrency_override is not None and max_concurrency_cap is not None: raise ValueError( - "If `max_parallelism_override` specified, cannot also apply " - "`max_parallelism_cap`." + "If `max_concurrency_override` specified, cannot also apply " + "`max_concurrency_cap`." ) # If number of initialization trials is not specified, estimate it. 
@@ -503,7 +536,7 @@ def choose_generation_strategy_legacy( min_trials_observed=min_sobol_trials_observed, enforce_num_trials=enforce_sequential_optimization, seed=random_seed, - max_parallelism=sobol_parallelism, + max_concurrency=sobol_concurrency, should_deduplicate=should_deduplicate, ) ) @@ -512,7 +545,7 @@ def choose_generation_strategy_legacy( generator=suggested_model, winsorization_config=winsorization_config, derelativize_with_raw_status_quo=derelativize_with_raw_status_quo, - max_parallelism=bo_parallelism, + max_concurrency=bo_concurrency, generator_kwargs=generator_kwargs, should_deduplicate=should_deduplicate, disable_progbar=disable_progbar, @@ -544,7 +577,7 @@ def choose_generation_strategy_legacy( _make_sobol_step( seed=random_seed, should_deduplicate=should_deduplicate, - max_parallelism=sobol_parallelism, + max_concurrency=sobol_concurrency, ) ] ) diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index 435bccc9f8d..428e36560a5 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -8,6 +8,7 @@ from __future__ import annotations +import warnings from collections import defaultdict from collections.abc import Sequence from logging import Logger @@ -971,9 +972,9 @@ class GenerationStep: If `num_trials` of a given step have been generated but `min_trials_ observed` have not been completed, a call to `generation_strategy.gen` will fail with a `DataRequiredError`. - max_parallelism: How many trials generated in the course of this step are + max_concurrency: How many trials generated in the course of this step are allowed to be run (i.e. have `trial.status` of `RUNNING`) simultaneously. 
- If `max_parallelism` trials from this step are already running, a call + If `max_concurrency` trials from this step are already running, a call to `generation_strategy.gen` will fail with a `MaxParallelismReached Exception`, indicating that more trials need to be completed before generating and running next trials. @@ -1026,7 +1027,7 @@ def __new__( generator_kwargs: dict[str, Any] | None = None, generator_gen_kwargs: dict[str, Any] | None = None, min_trials_observed: int = 0, - max_parallelism: int | None = None, + max_concurrency: int | None = None, enforce_num_trials: bool = True, should_deduplicate: bool = False, generator_name: str | None = None, @@ -1037,6 +1038,7 @@ def __new__( # Deprecated arguments for backwards compatibility. model_kwargs: dict[str, Any] | None = None, model_gen_kwargs: dict[str, Any] | None = None, + max_parallelism: int | None = None, # DEPRECATED: use max_concurrency. ) -> GenerationNode: r"""Creates a ``GenerationNode`` configured as a single-model generation step. @@ -1048,15 +1050,29 @@ def __new__( if use_update: raise DeprecationWarning("`GenerationStep.use_update` is deprecated.") + # Handle deprecated `max_parallelism` argument. + if max_parallelism is not None: + warnings.warn( + "`max_parallelism` is deprecated and will be removed in Ax 1.4. " + "Use `max_concurrency` instead.", + DeprecationWarning, + stacklevel=2, + ) + if max_concurrency is not None: + raise UserInputError( + "Cannot specify both `max_parallelism` and `max_concurrency`." + ) + max_concurrency = max_parallelism + if num_trials < 1 and num_trials != -1: raise UserInputError( "`num_trials` must be positive or -1 (indicating unlimited) " "for all generation steps." ) - if max_parallelism is not None and max_parallelism < 1: + if max_concurrency is not None and max_concurrency < 1: raise UserInputError( - "Maximum parallelism should be None (if no limit) or " - f"a positive number. 
Got: {max_parallelism} for " + "Maximum concurrency should be None (if no limit) or " + f"a positive number. Got: {max_concurrency} for " f"step {generator_name}." ) @@ -1130,10 +1146,10 @@ def __new__( use_all_trials_in_exp=use_all_trials_in_exp, ) ) - if max_parallelism is not None: + if max_concurrency is not None: transition_criteria.append( MaxGenerationParallelism( - threshold=max_parallelism, + threshold=max_concurrency, transition_to=placeholder_transition_to, only_in_statuses=[TrialStatus.RUNNING], block_gen_if_met=True, diff --git a/ax/generation_strategy/tests/test_dispatch_utils.py b/ax/generation_strategy/tests/test_dispatch_utils.py index c382dd9fe3a..1b9b0803fbf 100644 --- a/ax/generation_strategy/tests/test_dispatch_utils.py +++ b/ax/generation_strategy/tests/test_dispatch_utils.py @@ -19,7 +19,7 @@ _make_botorch_step, calculate_num_initialization_trials, choose_generation_strategy_legacy, - DEFAULT_BAYESIAN_PARALLELISM, + DEFAULT_BAYESIAN_CONCURRENCY, ) from ax.generation_strategy.generation_node import GenerationNode from ax.generation_strategy.transition_criterion import ( @@ -621,14 +621,14 @@ def test_enforce_sequential_optimization(self) -> None: sobol_gpei._nodes[0].transition_criteria[0], MinTrials ) self.assertTrue(node0_min_trials.block_gen_if_met) - # Check that max_parallelism is set by verifying MaxGenerationParallelism + # Check that max_concurrency is set by verifying MaxGenerationParallelism # criterion exists on node 1 - node1_max_parallelism = [ + node1_max_concurrency = [ tc for tc in sobol_gpei._nodes[1].transition_criteria if isinstance(tc, MaxGenerationParallelism) ] - self.assertTrue(len(node1_max_parallelism) > 0) + self.assertTrue(len(node1_max_concurrency) > 0) with self.subTest("False"): sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space(), @@ -646,22 +646,22 @@ def test_enforce_sequential_optimization(self) -> None: sobol_gpei._nodes[0].transition_criteria[0], MinTrials ) 
self.assertFalse(node0_min_trials.block_gen_if_met) - # Check that max_parallelism is None by verifying no + # Check that max_concurrency is None by verifying no # MaxGenerationParallelism criterion exists on node 1 - node1_max_parallelism = [ + node1_max_concurrency = [ tc for tc in sobol_gpei._nodes[1].transition_criteria if isinstance(tc, MaxGenerationParallelism) ] - self.assertEqual(len(node1_max_parallelism), 0) - with self.subTest("False and max_parallelism_override"): + self.assertEqual(len(node1_max_concurrency), 0) + with self.subTest("False and max_concurrency_override"): with self.assertLogs( choose_generation_strategy_legacy.__module__, logging.INFO ) as logger: choose_generation_strategy_legacy( search_space=get_branin_search_space(), enforce_sequential_optimization=False, - max_parallelism_override=5, + max_concurrency_override=5, ) self.assertTrue( any( @@ -670,14 +670,14 @@ def test_enforce_sequential_optimization(self) -> None: ), logger.output, ) - with self.subTest("False and max_parallelism_cap"): + with self.subTest("False and max_concurrency_cap"): with self.assertLogs( choose_generation_strategy_legacy.__module__, logging.INFO ) as logger: choose_generation_strategy_legacy( search_space=get_branin_search_space(), enforce_sequential_optimization=False, - max_parallelism_cap=5, + max_concurrency_cap=5, ) self.assertTrue( any( @@ -686,27 +686,27 @@ def test_enforce_sequential_optimization(self) -> None: ), logger.output, ) - with self.subTest("False and max_parallelism_override and max_parallelism_cap"): + with self.subTest("False and max_concurrency_override and max_concurrency_cap"): with self.assertRaisesRegex( ValueError, ( - "If `max_parallelism_override` specified, cannot also apply " - "`max_parallelism_cap`." + "If `max_concurrency_override` specified, cannot also apply " + "`max_concurrency_cap`." 
), ): choose_generation_strategy_legacy( search_space=get_branin_search_space(), enforce_sequential_optimization=False, - max_parallelism_override=5, - max_parallelism_cap=5, + max_concurrency_override=5, + max_concurrency_cap=5, ) - def test_max_parallelism_override(self) -> None: + def test_max_concurrency_override(self) -> None: sobol_gpei = choose_generation_strategy_legacy( - search_space=get_branin_search_space(), max_parallelism_override=10 + search_space=get_branin_search_space(), max_concurrency_override=10 ) self.assertTrue( - all(self._get_max_parallelism(s) == 10 for s in sobol_gpei._nodes) + all(self._get_max_concurrency(s) == 10 for s in sobol_gpei._nodes) ) def test_winsorization(self) -> None: @@ -817,47 +817,47 @@ def test_fixed_num_initialization_trials(self) -> None: 3, ) - def _get_max_parallelism(self, node: GenerationNode) -> int | None: - """Helper to extract max_parallelism from transition criteria.""" + def _get_max_concurrency(self, node: GenerationNode) -> int | None: + """Helper to extract max_concurrency from transition criteria.""" for tc in node.transition_criteria: if isinstance(tc, MaxGenerationParallelism): return tc.threshold return None - def test_max_parallelism_adjustments(self) -> None: + def test_max_concurrency_adjustments(self) -> None: # No adjustment. sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space() ) - self.assertIsNone(self._get_max_parallelism(sobol_gpei._nodes[0])) + self.assertIsNone(self._get_max_concurrency(sobol_gpei._nodes[0])) self.assertEqual( - self._get_max_parallelism(sobol_gpei._nodes[1]), - DEFAULT_BAYESIAN_PARALLELISM, + self._get_max_concurrency(sobol_gpei._nodes[1]), + DEFAULT_BAYESIAN_CONCURRENCY, ) # Impose a cap of 1 on max parallelism for all steps. 
sobol_gpei = choose_generation_strategy_legacy( - search_space=get_branin_search_space(), max_parallelism_cap=1 + search_space=get_branin_search_space(), max_concurrency_cap=1 ) self.assertEqual( - self._get_max_parallelism(sobol_gpei._nodes[0]), + self._get_max_concurrency(sobol_gpei._nodes[0]), 1, ) self.assertEqual( - self._get_max_parallelism(sobol_gpei._nodes[1]), + self._get_max_concurrency(sobol_gpei._nodes[1]), 1, ) # Disable enforcing max parallelism for all steps. sobol_gpei = choose_generation_strategy_legacy( - search_space=get_branin_search_space(), max_parallelism_override=-1 + search_space=get_branin_search_space(), max_concurrency_override=-1 ) - self.assertIsNone(self._get_max_parallelism(sobol_gpei._nodes[0])) - self.assertIsNone(self._get_max_parallelism(sobol_gpei._nodes[1])) + self.assertIsNone(self._get_max_concurrency(sobol_gpei._nodes[0])) + self.assertIsNone(self._get_max_concurrency(sobol_gpei._nodes[1])) # Override max parallelism for all steps. sobol_gpei = choose_generation_strategy_legacy( - search_space=get_branin_search_space(), max_parallelism_override=10 + search_space=get_branin_search_space(), max_concurrency_override=10 ) - self.assertEqual(self._get_max_parallelism(sobol_gpei._nodes[0]), 10) - self.assertEqual(self._get_max_parallelism(sobol_gpei._nodes[1]), 10) + self.assertEqual(self._get_max_concurrency(sobol_gpei._nodes[0]), 10) + self.assertEqual(self._get_max_concurrency(sobol_gpei._nodes[1]), 10) def test_set_should_deduplicate(self) -> None: sobol_gpei = choose_generation_strategy_legacy( From f13083dff2bf8386f088ee42109917161330fa40 Mon Sep 17 00:00:00 2001 From: Matthew Grange Date: Thu, 19 Feb 2026 13:01:55 -0800 Subject: [PATCH 2/6] Rename parallelism to concurrency in Client and AxClient APIs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Renames the `parallelism` parameter to `concurrency` in `Client.run_trials()` and adds backward-compatible deprecated 
`max_parallelism` parameters in `AxClient.create_experiment()` and `AxClient.get_max_parallelism()` → `get_max_concurrency()`. Both include deprecation warnings guiding callers to use the new parameter names, with validation that old and new parameters are not specified simultaneously. Differential Revision: D93771849 --- ax/api/client.py | 22 +++++++++++++-- ax/service/ax_client.py | 44 ++++++++++++++++++------------ ax/service/tests/test_ax_client.py | 18 ++++++------ 3 files changed, 55 insertions(+), 29 deletions(-) diff --git a/ax/api/client.py b/ax/api/client.py index 8b1cf583fe8..b5e2c1ce488 100644 --- a/ax/api/client.py +++ b/ax/api/client.py @@ -6,6 +6,7 @@ # pyre-strict import json +import warnings from collections.abc import Iterable, Sequence from logging import Logger from typing import Any, Literal, Self @@ -43,7 +44,7 @@ BaseEarlyStoppingStrategy, PercentileEarlyStoppingStrategy, ) -from ax.exceptions.core import ObjectNotFoundError, UnsupportedError +from ax.exceptions.core import ObjectNotFoundError, UnsupportedError, UserInputError from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.orchestration.orchestrator import Orchestrator, OrchestratorOptions from ax.service.utils.best_point_mixin import BestPointMixin @@ -710,9 +711,11 @@ def mark_trial_early_stopped(self, trial_index: int) -> None: def run_trials( self, max_trials: int, - parallelism: int = 1, + concurrency: int = 1, tolerated_trial_failure_rate: float = 0.5, initial_seconds_between_polls: int = 1, + # Deprecated argument for backwards compatibility. + parallelism: int | None = None, ) -> None: """ Run maximum_trials trials in a loop by creating an ephemeral Orchestrator under @@ -721,12 +724,25 @@ def run_trials( Saves to database on completion if ``storage_config`` is present. """ + # Handle deprecated `parallelism` argument. + if parallelism is not None: + warnings.warn( + "`parallelism` is deprecated and will be removed in Ax 1.4. 
" + "Use `concurrency` instead.", + DeprecationWarning, + stacklevel=2, + ) + if concurrency != 1: + raise UserInputError( + "Cannot specify both `parallelism` and `concurrency`." + ) + concurrency = parallelism orchestrator = Orchestrator( experiment=self._experiment, generation_strategy=self._generation_strategy_or_choose(), options=OrchestratorOptions( - max_pending_trials=parallelism, + max_pending_trials=concurrency, tolerated_trial_failure_rate=tolerated_trial_failure_rate, init_seconds_between_polls=initial_seconds_between_polls, ), diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index cb49ed1e39c..df9ef38d9bb 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -836,39 +836,39 @@ def get_trials_data_frame(self) -> pd.DataFrame: """ return self.experiment.to_df() - def get_max_parallelism(self) -> list[tuple[int, int]]: - """Retrieves maximum number of trials that can be scheduled in parallel + def get_max_concurrency(self) -> list[tuple[int, int]]: + """Retrieves maximum number of trials that can be scheduled concurrently at different stages of optimization. Some optimization algorithms profit significantly from sequential optimization (i.e. suggest a few points, get updated with data for them, repeat, see https://ax.dev/docs/bayesopt.html). - Parallelism setting indicates how many trials should be running simulteneously + Concurrency setting indicates how many trials should be running simultaneously (generated, but not yet completed with data). The output of this method is mapping of form - {num_trials -> max_parallelism_setting}, where the max_parallelism_setting - is used for num_trials trials. If max_parallelism_setting is -1, as - many of the trials can be ran in parallel, as necessary. If num_trials - in a tuple is -1, then the corresponding max_parallelism_setting + {num_trials -> max_concurrency_setting}, where the max_concurrency_setting + is used for num_trials trials. 
If max_concurrency_setting is -1, as + many of the trials can be run concurrently, as necessary. If num_trials + in a tuple is -1, then the corresponding max_concurrency_setting should be used for all subsequent trials. For example, if the returned list is [(5, -1), (12, 6), (-1, 3)], - the schedule could be: run 5 trials with any parallelism, run 6 trials in - parallel twice, run 3 trials in parallel for as long as needed. Here, + the schedule could be: run 5 trials with any concurrency, run 6 trials + concurrently twice, run 3 trials concurrently for as long as needed. Here, 'running' a trial means obtaining a next trial from `AxClient` through get_next_trials and completing it with data when available. Returns: - Mapping of form {num_trials -> max_parallelism_setting}. + Mapping of form {num_trials -> max_concurrency_setting}. """ - parallelism_settings = [] + concurrency_settings = [] for node in self.generation_strategy._nodes: - # Extract max_parallelism from MaxGenerationParallelism criterion - max_parallelism = None + # Extract max_concurrency from MaxGenerationParallelism criterion + max_concurrency = None for tc in node.transition_criteria: if isinstance(tc, MaxGenerationParallelism): - max_parallelism = tc.threshold + max_concurrency = tc.threshold break # Try to get num_trials from the node. If there's no MinTrials # criterion (unlimited trials), num_trials will raise UserInputError. @@ -877,13 +877,23 @@ def get_max_parallelism(self) -> list[tuple[int, int]]: num_trials = node.num_trials except UserInputError: num_trials = -1 - parallelism_settings.append( + concurrency_settings.append( ( num_trials, - max_parallelism if max_parallelism is not None else num_trials, + max_concurrency if max_concurrency is not None else num_trials, ) ) - return parallelism_settings + return concurrency_settings + + def get_max_parallelism(self) -> list[tuple[int, int]]: + """Deprecated.
Use `get_max_concurrency` instead.""" + warnings.warn( + "`get_max_parallelism` is deprecated and will be removed in Ax 1.4. " + "Use `get_max_concurrency` instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.get_max_concurrency() def get_optimization_trace( self, objective_optimum: float | None = None diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index afe81f05d01..105bc15f0ea 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -50,7 +50,7 @@ UserInputError, ) from ax.exceptions.generation_strategy import MaxParallelismReachedException -from ax.generation_strategy.dispatch_utils import DEFAULT_BAYESIAN_PARALLELISM +from ax.generation_strategy.dispatch_utils import DEFAULT_BAYESIAN_CONCURRENCY from ax.generation_strategy.generation_strategy import ( GenerationNode, GenerationStep, @@ -511,7 +511,7 @@ def test_default_generation_strategy_continuous(self) -> None: if i < 5: self.assertEqual(gen_limit, 5 - i) else: - self.assertEqual(gen_limit, DEFAULT_BAYESIAN_PARALLELISM) + self.assertEqual(gen_limit, DEFAULT_BAYESIAN_CONCURRENCY) parameterization, trial_index = ax_client.get_next_trial() x, y = parameterization.get("x"), parameterization.get("y") ax_client.complete_trial( @@ -1616,14 +1616,14 @@ def test_keep_generating_without_data(self) -> None: self.assertTrue(len(node0_min_trials) > 0) self.assertFalse(node0_min_trials[0].block_gen_if_met) - # Check that max_parallelism is None by verifying no MaxGenerationParallelism + # Check that max_concurrency is None by verifying no MaxGenerationParallelism # criterion exists on node 1 - node1_max_parallelism = [ + node1_max_concurrency = [ tc for tc in ax_client.generation_strategy._nodes[1].transition_criteria if isinstance(tc, MaxGenerationParallelism) ] - self.assertEqual(len(node1_max_parallelism), 0) + self.assertEqual(len(node1_max_concurrency), 0) for _ in range(10): ax_client.get_next_trial() @@ -1939,17 +1939,17 @@ def 
test_relative_oc_without_sq(self) -> None: def test_recommended_parallelism(self) -> None: ax_client = AxClient() with self.assertRaisesRegex(AssertionError, "No generation strategy"): - ax_client.get_max_parallelism() + ax_client.get_max_concurrency() ax_client.create_experiment( parameters=[ {"name": "x", "type": "range", "bounds": [-5.0, 10.0]}, {"name": "y", "type": "range", "bounds": [0.0, 15.0]}, ], ) - self.assertEqual(ax_client.get_max_parallelism(), [(5, 5), (-1, 3)]) + self.assertEqual(ax_client.get_max_concurrency(), [(5, 5), (-1, 3)]) self.assertEqual( run_trials_using_recommended_parallelism( - ax_client, ax_client.get_max_parallelism(), 20 + ax_client, ax_client.get_max_concurrency(), 20 ), 0, ) @@ -2872,7 +2872,7 @@ def test_estimate_early_stopping_savings(self) -> None: self.assertEqual(ax_client.estimate_early_stopping_savings(), 0) - def test_max_parallelism_exception_when_early_stopping(self) -> None: + def test_max_concurrency_exception_when_early_stopping(self) -> None: ax_client = AxClient() ax_client.create_experiment( parameters=[ From dfba2e6fa4574b48e8b11edae7ecda78982052ca Mon Sep 17 00:00:00 2001 From: Matthew Grange Date: Thu, 19 Feb 2026 13:02:43 -0800 Subject: [PATCH 3/6] Rename num_parallel_jobs to num_concurrent_jobs in BenchmarkExecutionSettings Summary: Renames `num_parallel_jobs` to `num_concurrent_jobs` in `BenchmarkExecutionSettings` and all nightly benchmark configurations. Also updates the docstring in `BenchmarkMethod` to reference "pending trials" instead of "parallelism". This is a mechanical rename with no behavioral change. 
Differential Revision: D93771883 --- ax/benchmark/benchmark_method.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ax/benchmark/benchmark_method.py b/ax/benchmark/benchmark_method.py index bea5463bd2f..19ec4631ce8 100644 --- a/ax/benchmark/benchmark_method.py +++ b/ax/benchmark/benchmark_method.py @@ -16,7 +16,7 @@ class BenchmarkMethod(Base): """Benchmark method, represented in terms of Ax generation strategy (which tells us which models to use when) and Orchestrator options (which tell us extra execution - information like maximum parallelism, early stopping configuration, etc.). + information like maximum pending trials, early stopping configuration, etc.). Args: name: String description. Defaults to the name of the generation strategy. From 326e2388f51200af46a9a4e47df52b0ad628c070 Mon Sep 17 00:00:00 2001 From: Matthew Grange Date: Thu, 19 Feb 2026 13:03:22 -0800 Subject: [PATCH 4/6] Rename parallelism references to concurrency in docs, comments, constants, and telemetry MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Updates remaining references from "parallelism" to "concurrency" across orchestration, telemetry, early stopping, and other modules. This covers docstrings, comments, constant names (`MAX_PENDING_TRIALS` → `MAX_CONCURRENT_TRIALS`, `DUMMY_MAX_PENDING_TRIALS` → `DUMMY_MAX_CONCURRENT_TRIALS`), telemetry field names, and variable names in test files. No behavioral changes — purely a terminology alignment. 
Differential Revision: D93771906 --- ax/analysis/healthcheck/early_stopping_healthcheck.py | 6 +++--- ax/core/runner.py | 2 +- ax/early_stopping/experiment_replay.py | 6 +++--- ax/orchestration/orchestrator.py | 7 ++++--- ax/orchestration/orchestrator_options.py | 2 +- ax/utils/common/complexity_utils.py | 2 +- 6 files changed, 13 insertions(+), 12 deletions(-) diff --git a/ax/analysis/healthcheck/early_stopping_healthcheck.py b/ax/analysis/healthcheck/early_stopping_healthcheck.py index 49bfca2e30c..43b63de51e1 100644 --- a/ax/analysis/healthcheck/early_stopping_healthcheck.py +++ b/ax/analysis/healthcheck/early_stopping_healthcheck.py @@ -23,7 +23,7 @@ from ax.early_stopping.dispatch import get_default_ess_or_none from ax.early_stopping.experiment_replay import ( estimate_hypothetical_early_stopping_savings, - MAX_PENDING_TRIALS, + MAX_CONCURRENT_TRIALS, MIN_SAVINGS_THRESHOLD, ) from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy @@ -81,7 +81,7 @@ def __init__( self, early_stopping_strategy: BaseEarlyStoppingStrategy | None = None, min_savings_threshold: float = MIN_SAVINGS_THRESHOLD, - max_pending_trials: int = MAX_PENDING_TRIALS, + max_pending_trials: int = MAX_CONCURRENT_TRIALS, auto_early_stopping_config: AutoEarlyStoppingConfig | None = None, nudge_additional_info: str | None = None, ) -> None: @@ -95,7 +95,7 @@ def __init__( single-objective unconstrained experiments. min_savings_threshold: Minimum savings threshold to suggest early stopping. Default is 0.1 (10% savings). - max_pending_trials: Maximum number of pending trials for replay + max_pending_trials: Maximum number of concurrent trials for replay orchestrator. Default is 5. auto_early_stopping_config: A string for configuring automated early stopping strategy. 
diff --git a/ax/core/runner.py b/ax/core/runner.py index 0fcf8abff20..c33ca9d9e66 100644 --- a/ax/core/runner.py +++ b/ax/core/runner.py @@ -81,7 +81,7 @@ def poll_available_capacity(self) -> int: artificially force this method to limit capacity; ``Orchestrator`` has other limitations in place to limit number of trials running at once, like the ``OrchestratorOptions.max_pending_trials`` setting, or - more granular control in the form of the `max_parallelism` + more granular control in the form of the `max_concurrency` setting in each of the `GenerationStep`s of a `GenerationStrategy`). Returns: diff --git a/ax/early_stopping/experiment_replay.py b/ax/early_stopping/experiment_replay.py index d4a1f5a5fe2..ab39d731081 100644 --- a/ax/early_stopping/experiment_replay.py +++ b/ax/early_stopping/experiment_replay.py @@ -35,7 +35,7 @@ # Constants for experiment replay MAX_REPLAY_TRIALS: int = 50 REPLAY_NUM_POINTS_PER_CURVE: int = 20 -MAX_PENDING_TRIALS: int = 5 +MAX_CONCURRENT_TRIALS: int = 5 MIN_SAVINGS_THRESHOLD: float = 0.1 # 10% threshold @@ -119,7 +119,7 @@ def replay_experiment( def estimate_hypothetical_early_stopping_savings( experiment: Experiment, metric: Metric, - max_pending_trials: int = MAX_PENDING_TRIALS, + max_pending_trials: int = MAX_CONCURRENT_TRIALS, ) -> float: """Estimate hypothetical early stopping savings using experiment replay. @@ -130,7 +130,7 @@ def estimate_hypothetical_early_stopping_savings( Args: experiment: The experiment to analyze. metric: The metric to use for early stopping replay. - max_pending_trials: Maximum number of pending trials for the replay + max_pending_trials: Maximum number of concurrent trials for the replay orchestrator. Defaults to 5. 
Returns: diff --git a/ax/orchestration/orchestrator.py b/ax/orchestration/orchestrator.py index a56357a3b4c..19a07cd10a3 100644 --- a/ax/orchestration/orchestrator.py +++ b/ax/orchestration/orchestrator.py @@ -1645,7 +1645,7 @@ def _validate_options(self, options: OrchestratorOptions) -> None: ) def _get_max_pending_trials(self) -> int: - """Returns the maximum number of pending trials specified in the options, or + """Returns the maximum number of concurrent trials specified in the options, or zero, if the failure rate limit has been exceeded at any point during the optimization. """ @@ -1690,8 +1690,9 @@ def _prepare_trials( max_pending_upper_bound = max_pending_trials - num_pending_trials if max_pending_upper_bound < 1: self.logger.debug( - f"`max_pending_trials={max_pending_trials}` and {num_pending_trials} " - "trials are currently pending; not initiating any additional trials." + f"`max_pending_trials={max_pending_trials}` and " + f"{num_pending_trials} trials are currently pending; " + "not initiating any additional trials." ) return [], [] n = max_pending_upper_bound if n == -1 else min(max_pending_upper_bound, n) diff --git a/ax/orchestration/orchestrator_options.py b/ax/orchestration/orchestrator_options.py index 70ab3b04cf3..184db46fedc 100644 --- a/ax/orchestration/orchestrator_options.py +++ b/ax/orchestration/orchestrator_options.py @@ -90,7 +90,7 @@ class OrchestratorOptions: deployment. The size of the groups will be determined as the minimum of ``self.poll_available_capacity()`` and the number of generator runs that the generation strategy is able to produce - without more data or reaching its allowed max paralellism limit. + without more data or reaching its allowed max concurrency limit. debug_log_run_metadata: Whether to log run_metadata for debugging purposes. 
early_stopping_strategy: A ``BaseEarlyStoppingStrategy`` that determines whether a trial should be stopped given the current state of diff --git a/ax/utils/common/complexity_utils.py b/ax/utils/common/complexity_utils.py index a0edc391df2..9a3028ca76e 100644 --- a/ax/utils/common/complexity_utils.py +++ b/ax/utils/common/complexity_utils.py @@ -111,7 +111,7 @@ class OptimizationSummary: is True). tolerated_trial_failure_rate: Maximum tolerated trial failure rate (should be <= 0.9). - max_pending_trials: Maximum number of pending trials. + max_pending_trials: Maximum number of concurrent trials. min_failed_trials_for_failure_rate_check: Minimum failed trials before failure rate is checked. non_default_advanced_options: Whether non-default advanced options are set. From d600c9ec1e472b94aead582f488256e958ea5c18 Mon Sep 17 00:00:00 2001 From: Matthew Grange Date: Thu, 19 Feb 2026 13:04:00 -0800 Subject: [PATCH 5/6] Deal with `total_concurrency` and `n` weirdness: introduce `ExperimentDesign.concurrency_limit` Summary: As titled, adding a simple `ExperimentDesign` object. Putting it into properties for serialization for now, so as to not do duplicate work ahead of the storage refactor implementation (and also in case we change things while working on this stack). 
Differential Revision: D89770462 --- ax/core/experiment.py | 17 ++++++ ax/core/experiment_design.py | 34 ++++++++++++ ax/core/tests/test_experiment_design.py | 33 ++++++++++++ ax/storage/json_store/encoders.py | 10 +++- .../json_store/tests/test_json_store.py | 54 +++++++++++++++++++ ax/storage/sqa_store/encoder.py | 7 +++ ax/storage/sqa_store/tests/test_sqa_store.py | 27 ++++++++++ 7 files changed, 181 insertions(+), 1 deletion(-) create mode 100644 ax/core/experiment_design.py create mode 100644 ax/core/tests/test_experiment_design.py diff --git a/ax/core/experiment.py b/ax/core/experiment.py index fd1df933594..bdb25b87026 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -30,6 +30,7 @@ from ax.core.base_trial import BaseTrial from ax.core.batch_trial import BatchTrial from ax.core.data import combine_data_rows_favoring_recent, Data +from ax.core.experiment_design import EXPERIMENT_DESIGN_KEY, ExperimentDesign from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.llm_provider import LLMMessage @@ -152,6 +153,17 @@ def __init__( self._status: ExperimentStatus | None = None self._trials: dict[int, BaseTrial] = {} self._properties: dict[str, Any] = properties or {} + self._design: ExperimentDesign = ExperimentDesign() + # Properties is temporarily being used to serialize/deserialize the + # experiment design (see ax/storage/sqa_store/encoder.py). + # Eventually, design will have first-class support as an experiment + # attribute. + # TODO[drfreund, mpolson64]: Replace with proper storage as part of the + # refactor.
+ if ( + design_dict := self._properties.pop(EXPERIMENT_DESIGN_KEY, None) + ) is not None: + self._design.concurrency_limit = design_dict.get("concurrency_limit") # Initialize trial type to runner mapping self._default_trial_type = default_trial_type @@ -294,6 +306,11 @@ def experiment_status_from_generator_runs( return suggested_statuses.pop() + @property + def design(self) -> ExperimentDesign: + """The experiment design configuration.""" + return self._design + @property def search_space(self) -> SearchSpace: """The search space for this experiment. diff --git a/ax/core/experiment_design.py b/ax/core/experiment_design.py new file mode 100644 index 00000000000..1533bfd9cf0 --- /dev/null +++ b/ax/core/experiment_design.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass + +EXPERIMENT_DESIGN_KEY: str = "experiment_design" + + +@dataclass +class ExperimentDesign: + """Struct that holds "experiment design" configuration: these are + experiment-level settings that pertain to "how the experiment will be + run or conducted", but are agnostic to the specific evaluation + backend, to which the trials will be deployed. + + NOTE: In the future, we might treat concurrency limit as expressed + in terms of "full arm equivalents" as opposed to just "number of arms", + to cover for the multi-fidelity cases. + + NOTE: in ax/storage/sqa_store/encoder.py, attributes of this class + are automatically serialized and stored in experiment.properties + + Args: + concurrency_limit: Maximum number of arms to run within one or + multiple trials, in parallel. In experiments that consist of + Trials, this is equivalent to the total number of trials + that should run in parallel. In experiments with BatchTrials, + total number of arms can be spread across one or + multiple BatchTrials. 
+ """ + + concurrency_limit: int | None = None diff --git a/ax/core/tests/test_experiment_design.py b/ax/core/tests/test_experiment_design.py new file mode 100644 index 00000000000..6af6ed3d004 --- /dev/null +++ b/ax/core/tests/test_experiment_design.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, Dict + +from ax.core import Experiment +from ax.core.experiment_design import EXPERIMENT_DESIGN_KEY, ExperimentDesign +from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import get_branin_search_space + + +class ExperimentDesignTest(TestCase): + """Tests covering ExperimentDesign class and its usage in ax Experiment""" + + def test_experiment_design_property(self) -> None: + """Test that Experiment.design property returns ExperimentDesign instance.""" + experiment = Experiment( + name="test", + search_space=get_branin_search_space(), + ) + self.assertIsInstance(experiment.design, ExperimentDesign) + self.assertIsNone(experiment.design.concurrency_limit) + + properties: Dict[str, Any] = {EXPERIMENT_DESIGN_KEY: {"concurrency_limit": 42}} + experiment = Experiment( + name="test", search_space=get_branin_search_space(), properties=properties + ) + self.assertEqual(experiment.design.concurrency_limit, 42) diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index 87b6b44c045..d6b5e248123 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -16,6 +16,7 @@ from ax.core.auxiliary import AuxiliaryExperiment from ax.core.batch_trial import BatchTrial from ax.core.data import Data +from ax.core.experiment_design import EXPERIMENT_DESIGN_KEY from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric from ax.core.multi_type_experiment import 
MultiTypeExperiment @@ -76,6 +77,13 @@ def experiment_to_dict(experiment: Experiment) -> dict[str, Any]: """Convert Ax experiment to a dictionary.""" + # Serialize ExperimentDesign into properties + properties = { + **experiment._properties, + EXPERIMENT_DESIGN_KEY: { + "concurrency_limit": experiment.design.concurrency_limit, + }, + } return { "__type": experiment.__class__.__name__, "name": experiment._name, @@ -90,7 +98,7 @@ def experiment_to_dict(experiment: Experiment) -> dict[str, Any]: "trials": experiment.trials, "is_test": experiment.is_test, "data_by_trial": data_to_data_by_trial(data=experiment.data), - "properties": experiment._properties, + "properties": properties, "_trial_type_to_runner": experiment._trial_type_to_runner, } diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index 18156eb7beb..fc32660efd2 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -1660,6 +1660,60 @@ def test_experiment_with_pruning_target_json_roundtrip(self) -> None: ).pruning_target_parameterization, ) + def test_experiment_design_json_roundtrip(self) -> None: + """Test that ExperimentDesign is preserved through JSON serialization.""" + # Setup: create experiment and set concurrency_limit + experiment = get_branin_experiment() + experiment.design.concurrency_limit = 42 + + # Execute: save and load experiment through JSON + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as f: + save_experiment( + experiment, + f.name, + encoder_registry=CORE_ENCODER_REGISTRY, + class_encoder_registry=CORE_CLASS_ENCODER_REGISTRY, + ) + loaded_experiment = load_experiment( + f.name, + decoder_registry=CORE_DECODER_REGISTRY, + class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, + ) + + # Cleanup + os.remove(f.name) + + # Assert: confirm ExperimentDesign is preserved + self.assertEqual(experiment, loaded_experiment) + 
self.assertEqual(loaded_experiment.design.concurrency_limit, 42) + + def test_experiment_design_none_concurrency_json_roundtrip(self) -> None: + """Test that ExperimentDesign with None concurrency_limit is preserved.""" + # Setup: create experiment with default (None) concurrency_limit + experiment = get_branin_experiment() + self.assertIsNone(experiment.design.concurrency_limit) + + # Execute: save and load experiment through JSON + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as f: + save_experiment( + experiment, + f.name, + encoder_registry=CORE_ENCODER_REGISTRY, + class_encoder_registry=CORE_CLASS_ENCODER_REGISTRY, + ) + loaded_experiment = load_experiment( + f.name, + decoder_registry=CORE_DECODER_REGISTRY, + class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, + ) + + # Cleanup + os.remove(f.name) + + # Assert: confirm ExperimentDesign is preserved with None + self.assertEqual(experiment, loaded_experiment) + self.assertIsNone(loaded_experiment.design.concurrency_limit) + def test_multi_objective_from_json_warning(self) -> None: objectives = [get_objective()] diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index 87bf58cb88f..4cd25098d9d 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -27,6 +27,7 @@ from ax.core.data import Data from ax.core.evaluations_to_data import DataType from ax.core.experiment import Experiment +from ax.core.experiment_design import EXPERIMENT_DESIGN_KEY from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric from ax.core.multi_type_experiment import MultiTypeExperiment @@ -228,6 +229,12 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment: elif experiment.runner: runners.append(self.runner_to_sqa(none_throws(experiment.runner))) properties = experiment._properties.copy() + + properties[EXPERIMENT_DESIGN_KEY] = { + attribute: getattr(experiment.design, attribute) + for attribute in 
list(experiment.design.__dict__.keys()) + } + if ( oc := experiment.optimization_config ) is not None and oc.pruning_target_parameterization is not None: diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 30a915c3a07..72c2b7aa96c 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -424,6 +424,33 @@ def test_arm_parameter_values_cast_to_parameter_type(self) -> None: self.assertEqual(none_throws(status_quo).parameters["y"], 0.5) self.assertIs(type(none_throws(status_quo).parameters["y"]), float) + def test_experiment_design_sqa_roundtrip(self) -> None: + """Test that ExperimentDesign is preserved through SQA serialization.""" + # Create experiment and set concurrency_limit + experiment = get_experiment_with_batch_trial() + experiment.design.concurrency_limit = 42 + + # Save and load experiment through SQA + save_experiment(experiment) + loaded_experiment = load_experiment(experiment.name) + + # Verify ExperimentDesign, w/ concurrency value 42 is preserved + self.assertEqual(loaded_experiment, experiment) + self.assertEqual(loaded_experiment.design.concurrency_limit, 42) + + # Repeat process for none value for concurrency_limit + experiment = get_experiment_with_batch_trial() + experiment.name = "experiment_design_none_concurrency_test" + self.assertIsNone(experiment.design.concurrency_limit) + + # Save and load experiment through SQA + save_experiment(experiment) + loaded_experiment = load_experiment(experiment.name) + + # Verify ExperimentDesign with None is preserved + self.assertEqual(loaded_experiment, experiment) + self.assertIsNone(loaded_experiment.design.concurrency_limit) + def test_saving_and_loading_experiment_with_aux_exp(self) -> None: aux_experiment = Experiment( name="test_aux_exp_in_SQAStoreTest", From cbcb6485c5bfb562b2db6b6d9bb4e6bb3fa7c962 Mon Sep 17 00:00:00 2001 From: Matthew Grange Date: Fri, 20 Feb 2026 07:00:33 -0800 Subject: 
[PATCH 6/6] Use ExperimentDesign.concurrency_limit in Axolotl and GS Summary: Migrates all references from `experiment._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS]` to `experiment.design.concurrency_limit`, completing the transition to the `ExperimentDesign` dataclass introduced in the prior diff. This affects generation node input constructors (including `ALL_N` and `REPEAT_N`), the Axolotl updater, and associated tests. Also cleans up the `no-commit` code in `generation_node_input_constructors.py` to use the new `concurrency_limit` field with a fallback to a default of 10. Differential Revision: D89772029 --- .../generation_node_input_constructors.py | 32 +++++++++++-------- ...test_generation_node_input_constructors.py | 10 +++--- .../tests/test_generation_strategy.py | 2 +- 3 files changed, 24 insertions(+), 20 deletions(-) diff --git a/ax/generation_strategy/generation_node_input_constructors.py b/ax/generation_strategy/generation_node_input_constructors.py index 325d485bd64..cff09b0346b 100644 --- a/ax/generation_strategy/generation_node_input_constructors.py +++ b/ax/generation_strategy/generation_node_input_constructors.py @@ -264,19 +264,23 @@ def _get_default_n(experiment: Experiment, next_node: GenerationNode) -> int: if next_node.generator_spec_to_gen_from.generator_gen_kwargs.get("n") is not None: return next_node.generator_spec_to_gen_from.generator_gen_kwargs["n"] - if ( - exp_n := experiment._properties.get(Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS.value) - ) is None: - # GS default n is 1, but these input constructors are used for nodes that - # should generate more than 1 arm per trial, default to 10 - return 10 + n_for_this_trial = None + concurrency_limit = experiment.design.concurrency_limit - # If exp_n is set, we will use that in conjunction with the trial + # If concurrency_limit is set, we will use that in conjunction with the trial # type to determine the number of arms to generate from this node - # TODO #2 [drfreund, mgarrard]: Instead 
of this, short- and long-run nodes should - # just have different input constructors (`half_n_floor` and `half_n_ceil`). - if next_node._trial_type == Keys.SHORT_RUN: - return floor(0.5 * exp_n) - if next_node._trial_type == Keys.LONG_RUN: - return ceil(0.5 * exp_n) - return exp_n + if concurrency_limit is not None: + if next_node._trial_type == Keys.SHORT_RUN: + n_for_this_trial = floor(0.5 * concurrency_limit) + elif next_node._trial_type == Keys.LONG_RUN: + n_for_this_trial = ceil(0.5 * concurrency_limit) + else: + n_for_this_trial = concurrency_limit + + return ( + n_for_this_trial + if n_for_this_trial is not None + # GS default n is 1, but these input constructors are used for nodes that + # should generate more than 1 arm per trial, default to 10 + else 10 + ) diff --git a/ax/generation_strategy/tests/test_generation_node_input_constructors.py b/ax/generation_strategy/tests/test_generation_node_input_constructors.py index 88ee58eeee2..e0f3aad94af 100644 --- a/ax/generation_strategy/tests/test_generation_node_input_constructors.py +++ b/ax/generation_strategy/tests/test_generation_node_input_constructors.py @@ -192,7 +192,7 @@ def test_no_n_provided_all_n(self) -> None: self.assertEqual(num_to_gen, 10) def test_no_n_provided_all_n_with_exp_prop(self) -> None: - self.experiment._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS] = 12 + self.experiment._design.concurrency_limit = 12 num_to_gen = NodeInputConstructors.ALL_N( previous_node=None, next_node=self.sobol_generation_node, @@ -202,7 +202,7 @@ def test_no_n_provided_all_n_with_exp_prop(self) -> None: self.assertEqual(num_to_gen, 12) def test_no_n_provided_all_n_with_exp_prop_long_run(self) -> None: - self.experiment._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS] = 13 + self.experiment._design.concurrency_limit = 13 self.sobol_generation_node._trial_type = Keys.LONG_RUN num_to_gen = NodeInputConstructors.ALL_N( previous_node=None, @@ -213,7 +213,7 @@ def 
test_no_n_provided_all_n_with_exp_prop_long_run(self) -> None: self.assertEqual(num_to_gen, 7) def test_no_n_provided_all_n_with_exp_prop_short_run(self) -> None: - self.experiment._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS] = 13 + self.experiment._design.concurrency_limit = 13 self.sobol_generation_node._trial_type = Keys.SHORT_RUN num_to_gen = NodeInputConstructors.ALL_N( previous_node=None, @@ -233,7 +233,7 @@ def test_no_n_provided_repeat_n(self) -> None: self.assertEqual(num_to_gen, 1) def test_no_n_provided_repeat_n_with_exp_prop(self) -> None: - self.experiment._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS] = 18 + self.experiment._design.concurrency_limit = 18 num_to_gen = NodeInputConstructors.REPEAT_N( previous_node=None, next_node=self.sobol_generation_node, @@ -243,7 +243,7 @@ def test_no_n_provided_repeat_n_with_exp_prop(self) -> None: self.assertEqual(num_to_gen, 2) def test_no_n_provided_repeat_n_with_exp_prop_long_run(self) -> None: - self.experiment._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS] = 18 + self.experiment._design.concurrency_limit = 18 self.sobol_generation_node._trial_type = Keys.SHORT_RUN num_to_gen = NodeInputConstructors.REPEAT_N( previous_node=None, diff --git a/ax/generation_strategy/tests/test_generation_strategy.py b/ax/generation_strategy/tests/test_generation_strategy.py index 03747b127d8..226445b460d 100644 --- a/ax/generation_strategy/tests/test_generation_strategy.py +++ b/ax/generation_strategy/tests/test_generation_strategy.py @@ -1202,7 +1202,7 @@ def test_gen_for_multiple_uses_total_concurrent_arms_for_a_default( self.sobol_node._transition_criteria = [] gs = GenerationStrategy(nodes=[self.sobol_node], name="test") gs.experiment = exp - exp._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS.value] = 3 + exp._design.concurrency_limit = 3 grs = gs.gen(exp, num_trials=2) self.assertEqual(len(grs), 2) for gr_list in grs: