diff --git a/ax/analysis/healthcheck/early_stopping_healthcheck.py b/ax/analysis/healthcheck/early_stopping_healthcheck.py index 49bfca2e30c..43b63de51e1 100644 --- a/ax/analysis/healthcheck/early_stopping_healthcheck.py +++ b/ax/analysis/healthcheck/early_stopping_healthcheck.py @@ -23,7 +23,7 @@ from ax.early_stopping.dispatch import get_default_ess_or_none from ax.early_stopping.experiment_replay import ( estimate_hypothetical_early_stopping_savings, - MAX_PENDING_TRIALS, + MAX_CONCURRENT_TRIALS, MIN_SAVINGS_THRESHOLD, ) from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy @@ -81,7 +81,7 @@ def __init__( self, early_stopping_strategy: BaseEarlyStoppingStrategy | None = None, min_savings_threshold: float = MIN_SAVINGS_THRESHOLD, - max_pending_trials: int = MAX_PENDING_TRIALS, + max_pending_trials: int = MAX_CONCURRENT_TRIALS, auto_early_stopping_config: AutoEarlyStoppingConfig | None = None, nudge_additional_info: str | None = None, ) -> None: @@ -95,7 +95,7 @@ def __init__( single-objective unconstrained experiments. min_savings_threshold: Minimum savings threshold to suggest early stopping. Default is 0.1 (10% savings). - max_pending_trials: Maximum number of pending trials for replay + max_pending_trials: Maximum number of concurrent trials for replay orchestrator. Default is 5. auto_early_stopping_config: A string for configuring automated early stopping strategy. 
diff --git a/ax/api/client.py b/ax/api/client.py index 8b1cf583fe8..b5e2c1ce488 100644 --- a/ax/api/client.py +++ b/ax/api/client.py @@ -6,6 +6,7 @@ # pyre-strict import json +import warnings from collections.abc import Iterable, Sequence from logging import Logger from typing import Any, Literal, Self @@ -43,7 +44,7 @@ BaseEarlyStoppingStrategy, PercentileEarlyStoppingStrategy, ) -from ax.exceptions.core import ObjectNotFoundError, UnsupportedError +from ax.exceptions.core import ObjectNotFoundError, UnsupportedError, UserInputError from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.orchestration.orchestrator import Orchestrator, OrchestratorOptions from ax.service.utils.best_point_mixin import BestPointMixin @@ -710,9 +711,11 @@ def mark_trial_early_stopped(self, trial_index: int) -> None: def run_trials( self, max_trials: int, - parallelism: int = 1, + concurrency: int = 1, tolerated_trial_failure_rate: float = 0.5, initial_seconds_between_polls: int = 1, + # Deprecated argument for backwards compatibility. + parallelism: int | None = None, ) -> None: """ Run maximum_trials trials in a loop by creating an ephemeral Orchestrator under @@ -721,12 +724,25 @@ def run_trials( Saves to database on completion if ``storage_config`` is present. """ + # Handle deprecated `parallelism` argument. + if parallelism is not None: + warnings.warn( + "`parallelism` is deprecated and will be removed in Ax 1.4. " + "Use `concurrency` instead.", + DeprecationWarning, + stacklevel=2, + ) + if concurrency != 1: + raise UserInputError( + "Cannot specify both `parallelism` and `concurrency`." 
+ ) + concurrency = parallelism orchestrator = Orchestrator( experiment=self._experiment, generation_strategy=self._generation_strategy_or_choose(), options=OrchestratorOptions( - max_pending_trials=parallelism, + max_pending_trials=concurrency, tolerated_trial_failure_rate=tolerated_trial_failure_rate, init_seconds_between_polls=initial_seconds_between_polls, ), diff --git a/ax/benchmark/benchmark_method.py b/ax/benchmark/benchmark_method.py index bea5463bd2f..19ec4631ce8 100644 --- a/ax/benchmark/benchmark_method.py +++ b/ax/benchmark/benchmark_method.py @@ -16,7 +16,7 @@ class BenchmarkMethod(Base): """Benchmark method, represented in terms of Ax generation strategy (which tells us which models to use when) and Orchestrator options (which tell us extra execution - information like maximum parallelism, early stopping configuration, etc.). + information like maximum pending trials, early stopping configuration, etc.). Args: name: String description. Defaults to the name of the generation strategy. diff --git a/ax/core/experiment.py b/ax/core/experiment.py index fd1df933594..bdb25b87026 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -30,6 +30,7 @@ from ax.core.base_trial import BaseTrial from ax.core.batch_trial import BatchTrial from ax.core.data import combine_data_rows_favoring_recent, Data +from ax.core.experiment_design import EXPERIMENT_DESIGN_KEY, ExperimentDesign from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.llm_provider import LLMMessage @@ -152,6 +153,17 @@ def __init__( self._status: ExperimentStatus | None = None self._trials: dict[int, BaseTrial] = {} self._properties: dict[str, Any] = properties or {} + self._design: ExperimentDesign = ExperimentDesign() + # Properties is temporarily being used to serialize/deserialize experiment + # in fbcode/ax/storage/sqa_store/encoder.py. 
+ # Eventually design will have first-class support as an experiment + # attribute + # TODO[drfreund, mpolson64]: Replace with proper storage as part of the + # refactor. + if ( + design_dict := self._properties.pop(EXPERIMENT_DESIGN_KEY, None) + ) is not None: + self._design.concurrency_limit = design_dict.get("concurrency_limit") # Initialize trial type to runner mapping self._default_trial_type = default_trial_type @@ -294,6 +306,11 @@ def experiment_status_from_generator_runs( return suggested_statuses.pop() + @property + def design(self) -> ExperimentDesign: + """The experiment design configuration.""" + return self._design + @property def search_space(self) -> SearchSpace: """The search space for this experiment. diff --git a/ax/core/experiment_design.py b/ax/core/experiment_design.py new file mode 100644 index 00000000000..1533bfd9cf0 --- /dev/null +++ b/ax/core/experiment_design.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass + +EXPERIMENT_DESIGN_KEY: str = "experiment_design" + + +@dataclass +class ExperimentDesign: + """Struct that holds "experiment design" configuration: these are + experiment-level settings that pertain to "how the experiment will be + run or conducted", but are agnostic to the specific evaluation + backend, to which the trials will be deployed. + + NOTE: In the future, we might treat concurrency limit as expressed + in terms of "full arm equivalents" as opposed to just "number of arms", + to cover for the multi-fidelity cases. + + NOTE: in ax/storage/sqa_store/encoder.py, attributes of this class + are automatically serialized and stored in experiment.properties + + Args: + concurrency_limit: Maximum number of arms to run within one or + multiple trials, in parallel. 
In experiments that consist of + Trials, this is equivalent to the total number of trials + that should run in parallel. In experiments with BatchTrials, + total number of arms can be spread across one or + multiple BatchTrials. + """ + + concurrency_limit: int | None = None diff --git a/ax/core/runner.py b/ax/core/runner.py index 0fcf8abff20..c33ca9d9e66 100644 --- a/ax/core/runner.py +++ b/ax/core/runner.py @@ -81,7 +81,7 @@ def poll_available_capacity(self) -> int: artificially force this method to limit capacity; ``Orchestrator`` has other limitations in place to limit number of trials running at once, like the ``OrchestratorOptions.max_pending_trials`` setting, or - more granular control in the form of the `max_parallelism` + more granular control in the form of the `max_concurrency` setting in each of the `GenerationStep`s of a `GenerationStrategy`). Returns: diff --git a/ax/core/tests/test_experiment_design.py b/ax/core/tests/test_experiment_design.py new file mode 100644 index 00000000000..6af6ed3d004 --- /dev/null +++ b/ax/core/tests/test_experiment_design.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +from typing import Any, Dict + +from ax.core import Experiment +from ax.core.experiment_design import EXPERIMENT_DESIGN_KEY, ExperimentDesign +from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import get_branin_search_space + + +class ExperimentDesignTest(TestCase): + """Tests covering ExperimentDesign class and its usage in ax Experiment""" + + def test_experiment_design_property(self) -> None: + """Test that Experiment.design property returns ExperimentDesign instance.""" + experiment = Experiment( + name="test", + search_space=get_branin_search_space(), + ) + self.assertIsInstance(experiment.design, ExperimentDesign) + self.assertIsNone(experiment.design.concurrency_limit) + + properties: Dict[str, Any] = {EXPERIMENT_DESIGN_KEY: {"concurrency_limit": 42}} + experiment = Experiment( + name="test", search_space=get_branin_search_space(), properties=properties + ) + self.assertEqual(experiment.design.concurrency_limit, 42) diff --git a/ax/early_stopping/experiment_replay.py b/ax/early_stopping/experiment_replay.py index d4a1f5a5fe2..ab39d731081 100644 --- a/ax/early_stopping/experiment_replay.py +++ b/ax/early_stopping/experiment_replay.py @@ -35,7 +35,7 @@ # Constants for experiment replay MAX_REPLAY_TRIALS: int = 50 REPLAY_NUM_POINTS_PER_CURVE: int = 20 -MAX_PENDING_TRIALS: int = 5 +MAX_CONCURRENT_TRIALS: int = 5 MIN_SAVINGS_THRESHOLD: float = 0.1 # 10% threshold @@ -119,7 +119,7 @@ def replay_experiment( def estimate_hypothetical_early_stopping_savings( experiment: Experiment, metric: Metric, - max_pending_trials: int = MAX_PENDING_TRIALS, + max_pending_trials: int = MAX_CONCURRENT_TRIALS, ) -> float: """Estimate hypothetical early stopping savings using experiment replay. @@ -130,7 +130,7 @@ def estimate_hypothetical_early_stopping_savings( Args: experiment: The experiment to analyze. metric: The metric to use for early stopping replay. 
- max_pending_trials: Maximum number of pending trials for the replay + max_pending_trials: Maximum number of concurrent trials for the replay orchestrator. Defaults to 5. Returns: diff --git a/ax/generation_strategy/dispatch_utils.py b/ax/generation_strategy/dispatch_utils.py index f2721d47fc6..52326c66510 100644 --- a/ax/generation_strategy/dispatch_utils.py +++ b/ax/generation_strategy/dispatch_utils.py @@ -7,6 +7,7 @@ # pyre-strict import logging +import warnings from math import ceil from typing import Any @@ -16,6 +17,7 @@ from ax.core.optimization_config import OptimizationConfig from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter from ax.core.search_space import SearchSpace +from ax.exceptions.core import UserInputError from ax.generation_strategy.generation_strategy import ( GenerationStep, GenerationStrategy, @@ -31,7 +33,7 @@ logger: logging.Logger = get_logger(__name__) -DEFAULT_BAYESIAN_PARALLELISM = 3 +DEFAULT_BAYESIAN_CONCURRENCY = 3 # `BO_MIXED` optimizes all range parameters once for each combination of choice # parameters, then takes the optimum of those optima. 
The cost associated with this # method grows with the number of combinations, and so it is only used when the @@ -49,7 +51,7 @@ def _make_sobol_step( num_trials: int = -1, min_trials_observed: int | None = None, enforce_num_trials: bool = True, - max_parallelism: int | None = None, + max_concurrency: int | None = None, seed: int | None = None, should_deduplicate: bool = False, ) -> GenerationStep: @@ -62,7 +64,7 @@ def _make_sobol_step( ceil(num_trials / 2) if min_trials_observed is None else min_trials_observed ), enforce_num_trials=enforce_num_trials, - max_parallelism=max_parallelism, + max_concurrency=max_concurrency, generator_kwargs={"deduplicate": True, "seed": seed}, should_deduplicate=should_deduplicate, use_all_trials_in_exp=True, @@ -73,7 +75,7 @@ def _make_botorch_step( num_trials: int = -1, min_trials_observed: int | None = None, enforce_num_trials: bool = True, - max_parallelism: int | None = None, + max_concurrency: int | None = None, generator: GeneratorRegistryBase = Generators.BOTORCH_MODULAR, generator_kwargs: dict[str, Any] | None = None, winsorization_config: None @@ -130,7 +132,7 @@ def _make_botorch_step( ceil(num_trials / 2) if min_trials_observed is None else min_trials_observed ), enforce_num_trials=enforce_num_trials, - max_parallelism=max_parallelism, + max_concurrency=max_concurrency, generator_kwargs=generator_kwargs, should_deduplicate=should_deduplicate, ) @@ -300,8 +302,8 @@ def choose_generation_strategy_legacy( num_completed_initialization_trials: int = 0, max_initialization_trials: int | None = None, min_sobol_trials_observed: int | None = None, - max_parallelism_cap: int | None = None, - max_parallelism_override: int | None = None, + max_concurrency_cap: int | None = None, + max_concurrency_override: int | None = None, optimization_config: OptimizationConfig | None = None, should_deduplicate: bool = False, use_saasbo: bool = False, @@ -311,6 +313,9 @@ def choose_generation_strategy_legacy( suggested_model_override: 
GeneratorRegistryBase | None = None, use_input_warping: bool = False, simplify_parameter_changes: bool = False, + # Deprecated arguments for backwards compatibility. + max_parallelism_cap: int | None = None, + max_parallelism_override: int | None = None, ) -> GenerationStrategy: """Select an appropriate generation strategy based on the properties of the search space and expected settings of the experiment, such as number of @@ -325,11 +330,11 @@ def choose_generation_strategy_legacy( enforce_sequential_optimization: Whether to enforce that 1) the generation strategy needs to be updated with ``min_trials_observed`` observations for a given generation step before proceeding to the next one and 2) maximum - number of trials running at once (max_parallelism) if enforced for the - BayesOpt step. NOTE: ``max_parallelism_override`` and - ``max_parallelism_cap`` settings will still take their effect on max - parallelism even if ``enforce_sequential_optimization=False``, so if those - settings are specified, max parallelism will be enforced. + number of trials running at once (max_concurrency) if enforced for the + BayesOpt step. NOTE: ``max_concurrency_override`` and + ``max_concurrency_cap`` settings will still take their effect on max + concurrency even if ``enforce_sequential_optimization=False``, so if those + settings are specified, max concurrency will be enforced. random_seed: Fixed random seed for the Sobol generator. torch_device: The device to use for generation steps implemented in PyTorch (e.g. via BoTorch). Some generation steps (in particular EHVI-based ones @@ -360,21 +365,21 @@ def choose_generation_strategy_legacy( min_sobol_trials_observed: Minimum number of Sobol trials that must be observed before proceeding to the next generation step. Defaults to `ceil(num_initialization_trials / 2)`. - max_parallelism_cap: Integer cap on parallelism in this generation strategy. 
- If specified, ``max_parallelism`` setting in each generation step will be + max_concurrency_cap: Integer cap on concurrency in this generation strategy. + If specified, ``max_concurrency`` setting in each generation step will be set to the minimum of the default setting for that step and the value of - this cap. ``max_parallelism_cap`` is meant to just be a hard limit on - parallelism (e.g. to avoid overloading machine(s) that evaluate the + this cap. ``max_concurrency_cap`` is meant to just be a hard limit on + concurrency (e.g. to avoid overloading machine(s) that evaluate the experiment trials). Specify only if not specifying - ``max_parallelism_override``. - max_parallelism_override: Integer, with which to override the default max - parallelism setting for all steps in the generation strategy returned from - this function. Each generation step has a ``max_parallelism`` value, which + ``max_concurrency_override``. + max_concurrency_override: Integer, with which to override the default max + concurrency setting for all steps in the generation strategy returned from + this function. Each generation step has a ``max_concurrency`` value, which restricts how many trials can run simultaneously during a given generation - step. By default, the parallelism setting is chosen as appropriate for the - model in a given generation step. If ``max_parallelism_override`` is -1, - no max parallelism will be enforced for any step of the generation - strategy. Be aware that parallelism is limited to improve performance of + step. By default, the concurrency setting is chosen as appropriate for the + model in a given generation step. If ``max_concurrency_override`` is -1, + no max concurrency will be enforced for any step of the generation + strategy. Be aware that concurrency is limited to improve performance of Bayesian optimization, so only disable its limiting if necessary. optimization_config: Used to infer whether to use MOO. 
should_deduplicate: Whether to deduplicate the parameters of proposed arms @@ -407,6 +412,34 @@ def choose_generation_strategy_legacy( simplify parameter changes in arms generated via Bayesian Optimization by pruning irrelevant parameter changes. """ + # Handle deprecated arguments. + if max_parallelism_cap is not None: + warnings.warn( + "`max_parallelism_cap` is deprecated and will be removed in Ax 1.4. " + "Use `max_concurrency_cap` instead.", + DeprecationWarning, + stacklevel=2, + ) + if max_concurrency_cap is not None: + raise UserInputError( + "Cannot specify both `max_parallelism_cap` and `max_concurrency_cap`." + ) + max_concurrency_cap = max_parallelism_cap + + if max_parallelism_override is not None: + warnings.warn( + "`max_parallelism_override` is deprecated and will be removed in Ax 1.4. " + "Use `max_concurrency_override` instead.", + DeprecationWarning, + stacklevel=2, + ) + if max_concurrency_override is not None: + raise UserInputError( + "Cannot specify both `max_parallelism_override` and " + "`max_concurrency_override`." + ) + max_concurrency_override = max_parallelism_override + if experiment is not None and optimization_config is None: optimization_config = experiment.optimization_config @@ -416,36 +449,36 @@ def choose_generation_strategy_legacy( optimization_config=optimization_config, use_saasbo=use_saasbo, ) - # Determine max parallelism for the generation steps. - if max_parallelism_override == -1: - # `max_parallelism_override` of -1 means no max parallelism enforcement in - # the generation strategy, which means `max_parallelism=None` in gen. steps. 
- sobol_parallelism = bo_parallelism = None - elif max_parallelism_override is not None: - sobol_parallelism = bo_parallelism = max_parallelism_override - elif max_parallelism_cap is not None: # Max parallelism override is None by now - sobol_parallelism = max_parallelism_cap - bo_parallelism = min(max_parallelism_cap, DEFAULT_BAYESIAN_PARALLELISM) + # Determine max concurrency for the generation steps. + if max_concurrency_override == -1: + # `max_concurrency_override` of -1 means no max concurrency enforcement in + # the generation strategy, which means `max_concurrency=None` in gen. steps. + sobol_concurrency = bo_concurrency = None + elif max_concurrency_override is not None: + sobol_concurrency = bo_concurrency = max_concurrency_override + elif max_concurrency_cap is not None: # Max concurrency override is None by now + sobol_concurrency = max_concurrency_cap + bo_concurrency = min(max_concurrency_cap, DEFAULT_BAYESIAN_CONCURRENCY) elif not enforce_sequential_optimization: - # If no max parallelism settings specified and not enforcing sequential - # optimization, do not limit parallelism. - sobol_parallelism = bo_parallelism = None - else: # No additional max parallelism settings, use defaults - sobol_parallelism = None # No restriction on Sobol phase - bo_parallelism = DEFAULT_BAYESIAN_PARALLELISM + # If no max concurrency settings specified and not enforcing sequential + # optimization, do not limit concurrency. 
+ sobol_concurrency = bo_concurrency = None + else: # No additional max concurrency settings, use defaults + sobol_concurrency = None # No restriction on Sobol phase + bo_concurrency = DEFAULT_BAYESIAN_CONCURRENCY if not force_random_search and suggested_model is not None: if not enforce_sequential_optimization and ( - max_parallelism_override is not None or max_parallelism_cap is not None + max_concurrency_override is not None or max_concurrency_cap is not None ): logger.info( - "If `enforce_sequential_optimization` is False, max parallelism is " - "not enforced and other max parallelism settings will be ignored." + "If `enforce_sequential_optimization` is False, max concurrency is " + "not enforced and other max concurrency settings will be ignored." ) - if max_parallelism_override is not None and max_parallelism_cap is not None: + if max_concurrency_override is not None and max_concurrency_cap is not None: raise ValueError( - "If `max_parallelism_override` specified, cannot also apply " - "`max_parallelism_cap`." + "If `max_concurrency_override` specified, cannot also apply " + "`max_concurrency_cap`." ) # If number of initialization trials is not specified, estimate it. 
@@ -503,7 +536,7 @@ def choose_generation_strategy_legacy( min_trials_observed=min_sobol_trials_observed, enforce_num_trials=enforce_sequential_optimization, seed=random_seed, - max_parallelism=sobol_parallelism, + max_concurrency=sobol_concurrency, should_deduplicate=should_deduplicate, ) ) @@ -512,7 +545,7 @@ def choose_generation_strategy_legacy( generator=suggested_model, winsorization_config=winsorization_config, derelativize_with_raw_status_quo=derelativize_with_raw_status_quo, - max_parallelism=bo_parallelism, + max_concurrency=bo_concurrency, generator_kwargs=generator_kwargs, should_deduplicate=should_deduplicate, disable_progbar=disable_progbar, @@ -544,7 +577,7 @@ def choose_generation_strategy_legacy( _make_sobol_step( seed=random_seed, should_deduplicate=should_deduplicate, - max_parallelism=sobol_parallelism, + max_concurrency=sobol_concurrency, ) ] ) diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index 435bccc9f8d..428e36560a5 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -8,6 +8,7 @@ from __future__ import annotations +import warnings from collections import defaultdict from collections.abc import Sequence from logging import Logger @@ -971,9 +972,9 @@ class GenerationStep: If `num_trials` of a given step have been generated but `min_trials_ observed` have not been completed, a call to `generation_strategy.gen` will fail with a `DataRequiredError`. - max_parallelism: How many trials generated in the course of this step are + max_concurrency: How many trials generated in the course of this step are allowed to be run (i.e. have `trial.status` of `RUNNING`) simultaneously. 
- If `max_parallelism` trials from this step are already running, a call + If `max_concurrency` trials from this step are already running, a call to `generation_strategy.gen` will fail with a `MaxParallelismReached Exception`, indicating that more trials need to be completed before generating and running next trials. @@ -1026,7 +1027,7 @@ def __new__( generator_kwargs: dict[str, Any] | None = None, generator_gen_kwargs: dict[str, Any] | None = None, min_trials_observed: int = 0, - max_parallelism: int | None = None, + max_concurrency: int | None = None, enforce_num_trials: bool = True, should_deduplicate: bool = False, generator_name: str | None = None, @@ -1037,6 +1038,7 @@ def __new__( # Deprecated arguments for backwards compatibility. model_kwargs: dict[str, Any] | None = None, model_gen_kwargs: dict[str, Any] | None = None, + max_parallelism: int | None = None, # DEPRECATED: use max_concurrency. ) -> GenerationNode: r"""Creates a ``GenerationNode`` configured as a single-model generation step. @@ -1048,15 +1050,29 @@ def __new__( if use_update: raise DeprecationWarning("`GenerationStep.use_update` is deprecated.") + # Handle deprecated `max_parallelism` argument. + if max_parallelism is not None: + warnings.warn( + "`max_parallelism` is deprecated and will be removed in Ax 1.4. " + "Use `max_concurrency` instead.", + DeprecationWarning, + stacklevel=2, + ) + if max_concurrency is not None: + raise UserInputError( + "Cannot specify both `max_parallelism` and `max_concurrency`." + ) + max_concurrency = max_parallelism + if num_trials < 1 and num_trials != -1: raise UserInputError( "`num_trials` must be positive or -1 (indicating unlimited) " "for all generation steps." ) - if max_parallelism is not None and max_parallelism < 1: + if max_concurrency is not None and max_concurrency < 1: raise UserInputError( - "Maximum parallelism should be None (if no limit) or " - f"a positive number. 
Got: {max_parallelism} for " + "Maximum concurrency should be None (if no limit) or " + f"a positive number. Got: {max_concurrency} for " f"step {generator_name}." ) @@ -1130,10 +1146,10 @@ def __new__( use_all_trials_in_exp=use_all_trials_in_exp, ) ) - if max_parallelism is not None: + if max_concurrency is not None: transition_criteria.append( MaxGenerationParallelism( - threshold=max_parallelism, + threshold=max_concurrency, transition_to=placeholder_transition_to, only_in_statuses=[TrialStatus.RUNNING], block_gen_if_met=True, diff --git a/ax/generation_strategy/generation_node_input_constructors.py b/ax/generation_strategy/generation_node_input_constructors.py index 325d485bd64..cff09b0346b 100644 --- a/ax/generation_strategy/generation_node_input_constructors.py +++ b/ax/generation_strategy/generation_node_input_constructors.py @@ -264,19 +264,23 @@ def _get_default_n(experiment: Experiment, next_node: GenerationNode) -> int: if next_node.generator_spec_to_gen_from.generator_gen_kwargs.get("n") is not None: return next_node.generator_spec_to_gen_from.generator_gen_kwargs["n"] - if ( - exp_n := experiment._properties.get(Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS.value) - ) is None: - # GS default n is 1, but these input constructors are used for nodes that - # should generate more than 1 arm per trial, default to 10 - return 10 + n_for_this_trial = None + concurrency_limit = experiment.design.concurrency_limit - # If exp_n is set, we will use that in conjunction with the trial + # If concurrency_limit is set, we will use that in conjunction with the trial # type to determine the number of arms to generate from this node - # TODO #2 [drfreund, mgarrard]: Instead of this, short- and long-run nodes should - # just have different input constructors (`half_n_floor` and `half_n_ceil`). 
- if next_node._trial_type == Keys.SHORT_RUN: - return floor(0.5 * exp_n) - if next_node._trial_type == Keys.LONG_RUN: - return ceil(0.5 * exp_n) - return exp_n + if concurrency_limit is not None: + if next_node._trial_type == Keys.SHORT_RUN: + n_for_this_trial = floor(0.5 * concurrency_limit) + elif next_node._trial_type == Keys.LONG_RUN: + n_for_this_trial = ceil(0.5 * concurrency_limit) + else: + n_for_this_trial = concurrency_limit + + return ( + n_for_this_trial + if n_for_this_trial is not None + # GS default n is 1, but these input constructors are used for nodes that + # should generate more than 1 arm per trial, default to 10 + else 10 + ) diff --git a/ax/generation_strategy/tests/test_dispatch_utils.py b/ax/generation_strategy/tests/test_dispatch_utils.py index c382dd9fe3a..1b9b0803fbf 100644 --- a/ax/generation_strategy/tests/test_dispatch_utils.py +++ b/ax/generation_strategy/tests/test_dispatch_utils.py @@ -19,7 +19,7 @@ _make_botorch_step, calculate_num_initialization_trials, choose_generation_strategy_legacy, - DEFAULT_BAYESIAN_PARALLELISM, + DEFAULT_BAYESIAN_CONCURRENCY, ) from ax.generation_strategy.generation_node import GenerationNode from ax.generation_strategy.transition_criterion import ( @@ -621,14 +621,14 @@ def test_enforce_sequential_optimization(self) -> None: sobol_gpei._nodes[0].transition_criteria[0], MinTrials ) self.assertTrue(node0_min_trials.block_gen_if_met) - # Check that max_parallelism is set by verifying MaxGenerationParallelism + # Check that max_concurrency is set by verifying MaxGenerationParallelism # criterion exists on node 1 - node1_max_parallelism = [ + node1_max_concurrency = [ tc for tc in sobol_gpei._nodes[1].transition_criteria if isinstance(tc, MaxGenerationParallelism) ] - self.assertTrue(len(node1_max_parallelism) > 0) + self.assertTrue(len(node1_max_concurrency) > 0) with self.subTest("False"): sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space(), @@ -646,22 +646,22 @@ def 
test_enforce_sequential_optimization(self) -> None: sobol_gpei._nodes[0].transition_criteria[0], MinTrials ) self.assertFalse(node0_min_trials.block_gen_if_met) - # Check that max_parallelism is None by verifying no + # Check that max_concurrency is None by verifying no # MaxGenerationParallelism criterion exists on node 1 - node1_max_parallelism = [ + node1_max_concurrency = [ tc for tc in sobol_gpei._nodes[1].transition_criteria if isinstance(tc, MaxGenerationParallelism) ] - self.assertEqual(len(node1_max_parallelism), 0) - with self.subTest("False and max_parallelism_override"): + self.assertEqual(len(node1_max_concurrency), 0) + with self.subTest("False and max_concurrency_override"): with self.assertLogs( choose_generation_strategy_legacy.__module__, logging.INFO ) as logger: choose_generation_strategy_legacy( search_space=get_branin_search_space(), enforce_sequential_optimization=False, - max_parallelism_override=5, + max_concurrency_override=5, ) self.assertTrue( any( @@ -670,14 +670,14 @@ def test_enforce_sequential_optimization(self) -> None: ), logger.output, ) - with self.subTest("False and max_parallelism_cap"): + with self.subTest("False and max_concurrency_cap"): with self.assertLogs( choose_generation_strategy_legacy.__module__, logging.INFO ) as logger: choose_generation_strategy_legacy( search_space=get_branin_search_space(), enforce_sequential_optimization=False, - max_parallelism_cap=5, + max_concurrency_cap=5, ) self.assertTrue( any( @@ -686,27 +686,27 @@ def test_enforce_sequential_optimization(self) -> None: ), logger.output, ) - with self.subTest("False and max_parallelism_override and max_parallelism_cap"): + with self.subTest("False and max_concurrency_override and max_concurrency_cap"): with self.assertRaisesRegex( ValueError, ( - "If `max_parallelism_override` specified, cannot also apply " - "`max_parallelism_cap`." + "If `max_concurrency_override` specified, cannot also apply " + "`max_concurrency_cap`." 
), ): choose_generation_strategy_legacy( search_space=get_branin_search_space(), enforce_sequential_optimization=False, - max_parallelism_override=5, - max_parallelism_cap=5, + max_concurrency_override=5, + max_concurrency_cap=5, ) - def test_max_parallelism_override(self) -> None: + def test_max_concurrency_override(self) -> None: sobol_gpei = choose_generation_strategy_legacy( - search_space=get_branin_search_space(), max_parallelism_override=10 + search_space=get_branin_search_space(), max_concurrency_override=10 ) self.assertTrue( - all(self._get_max_parallelism(s) == 10 for s in sobol_gpei._nodes) + all(self._get_max_concurrency(s) == 10 for s in sobol_gpei._nodes) ) def test_winsorization(self) -> None: @@ -817,47 +817,47 @@ def test_fixed_num_initialization_trials(self) -> None: 3, ) - def _get_max_parallelism(self, node: GenerationNode) -> int | None: - """Helper to extract max_parallelism from transition criteria.""" + def _get_max_concurrency(self, node: GenerationNode) -> int | None: + """Helper to extract max_concurrency from transition criteria.""" for tc in node.transition_criteria: if isinstance(tc, MaxGenerationParallelism): return tc.threshold return None - def test_max_parallelism_adjustments(self) -> None: + def test_max_concurrency_adjustments(self) -> None: # No adjustment. sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space() ) - self.assertIsNone(self._get_max_parallelism(sobol_gpei._nodes[0])) + self.assertIsNone(self._get_max_concurrency(sobol_gpei._nodes[0])) self.assertEqual( - self._get_max_parallelism(sobol_gpei._nodes[1]), - DEFAULT_BAYESIAN_PARALLELISM, + self._get_max_concurrency(sobol_gpei._nodes[1]), + DEFAULT_BAYESIAN_CONCURRENCY, ) # Impose a cap of 1 on max parallelism for all steps. 
sobol_gpei = choose_generation_strategy_legacy( - search_space=get_branin_search_space(), max_parallelism_cap=1 + search_space=get_branin_search_space(), max_concurrency_cap=1 ) self.assertEqual( - self._get_max_parallelism(sobol_gpei._nodes[0]), + self._get_max_concurrency(sobol_gpei._nodes[0]), 1, ) self.assertEqual( - self._get_max_parallelism(sobol_gpei._nodes[1]), + self._get_max_concurrency(sobol_gpei._nodes[1]), 1, ) # Disable enforcing max parallelism for all steps. sobol_gpei = choose_generation_strategy_legacy( - search_space=get_branin_search_space(), max_parallelism_override=-1 + search_space=get_branin_search_space(), max_concurrency_override=-1 ) - self.assertIsNone(self._get_max_parallelism(sobol_gpei._nodes[0])) - self.assertIsNone(self._get_max_parallelism(sobol_gpei._nodes[1])) + self.assertIsNone(self._get_max_concurrency(sobol_gpei._nodes[0])) + self.assertIsNone(self._get_max_concurrency(sobol_gpei._nodes[1])) # Override max parallelism for all steps. sobol_gpei = choose_generation_strategy_legacy( - search_space=get_branin_search_space(), max_parallelism_override=10 + search_space=get_branin_search_space(), max_concurrency_override=10 ) - self.assertEqual(self._get_max_parallelism(sobol_gpei._nodes[0]), 10) - self.assertEqual(self._get_max_parallelism(sobol_gpei._nodes[1]), 10) + self.assertEqual(self._get_max_concurrency(sobol_gpei._nodes[0]), 10) + self.assertEqual(self._get_max_concurrency(sobol_gpei._nodes[1]), 10) def test_set_should_deduplicate(self) -> None: sobol_gpei = choose_generation_strategy_legacy( diff --git a/ax/generation_strategy/tests/test_generation_node_input_constructors.py b/ax/generation_strategy/tests/test_generation_node_input_constructors.py index 88ee58eeee2..e0f3aad94af 100644 --- a/ax/generation_strategy/tests/test_generation_node_input_constructors.py +++ b/ax/generation_strategy/tests/test_generation_node_input_constructors.py @@ -192,7 +192,7 @@ def test_no_n_provided_all_n(self) -> None: 
self.assertEqual(num_to_gen, 10) def test_no_n_provided_all_n_with_exp_prop(self) -> None: - self.experiment._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS] = 12 + self.experiment._design.concurrency_limit = 12 num_to_gen = NodeInputConstructors.ALL_N( previous_node=None, next_node=self.sobol_generation_node, @@ -202,7 +202,7 @@ def test_no_n_provided_all_n_with_exp_prop(self) -> None: self.assertEqual(num_to_gen, 12) def test_no_n_provided_all_n_with_exp_prop_long_run(self) -> None: - self.experiment._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS] = 13 + self.experiment._design.concurrency_limit = 13 self.sobol_generation_node._trial_type = Keys.LONG_RUN num_to_gen = NodeInputConstructors.ALL_N( previous_node=None, @@ -213,7 +213,7 @@ def test_no_n_provided_all_n_with_exp_prop_long_run(self) -> None: self.assertEqual(num_to_gen, 7) def test_no_n_provided_all_n_with_exp_prop_short_run(self) -> None: - self.experiment._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS] = 13 + self.experiment._design.concurrency_limit = 13 self.sobol_generation_node._trial_type = Keys.SHORT_RUN num_to_gen = NodeInputConstructors.ALL_N( previous_node=None, @@ -233,7 +233,7 @@ def test_no_n_provided_repeat_n(self) -> None: self.assertEqual(num_to_gen, 1) def test_no_n_provided_repeat_n_with_exp_prop(self) -> None: - self.experiment._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS] = 18 + self.experiment._design.concurrency_limit = 18 num_to_gen = NodeInputConstructors.REPEAT_N( previous_node=None, next_node=self.sobol_generation_node, @@ -243,7 +243,7 @@ def test_no_n_provided_repeat_n_with_exp_prop(self) -> None: self.assertEqual(num_to_gen, 2) def test_no_n_provided_repeat_n_with_exp_prop_long_run(self) -> None: - self.experiment._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS] = 18 + self.experiment._design.concurrency_limit = 18 self.sobol_generation_node._trial_type = Keys.SHORT_RUN num_to_gen = NodeInputConstructors.REPEAT_N( previous_node=None, diff --git 
a/ax/generation_strategy/tests/test_generation_strategy.py b/ax/generation_strategy/tests/test_generation_strategy.py index 03747b127d8..226445b460d 100644 --- a/ax/generation_strategy/tests/test_generation_strategy.py +++ b/ax/generation_strategy/tests/test_generation_strategy.py @@ -1202,7 +1202,7 @@ def test_gen_for_multiple_uses_total_concurrent_arms_for_a_default( self.sobol_node._transition_criteria = [] gs = GenerationStrategy(nodes=[self.sobol_node], name="test") gs.experiment = exp - exp._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS.value] = 3 + exp._design.concurrency_limit = 3 grs = gs.gen(exp, num_trials=2) self.assertEqual(len(grs), 2) for gr_list in grs: diff --git a/ax/orchestration/orchestrator.py b/ax/orchestration/orchestrator.py index a56357a3b4c..19a07cd10a3 100644 --- a/ax/orchestration/orchestrator.py +++ b/ax/orchestration/orchestrator.py @@ -1645,7 +1645,7 @@ def _validate_options(self, options: OrchestratorOptions) -> None: ) def _get_max_pending_trials(self) -> int: - """Returns the maximum number of pending trials specified in the options, or + """Returns the maximum number of concurrent trials specified in the options, or zero, if the failure rate limit has been exceeded at any point during the optimization. """ @@ -1690,8 +1690,9 @@ def _prepare_trials( max_pending_upper_bound = max_pending_trials - num_pending_trials if max_pending_upper_bound < 1: self.logger.debug( - f"`max_pending_trials={max_pending_trials}` and {num_pending_trials} " - "trials are currently pending; not initiating any additional trials." + f"`max_pending_trials={max_pending_trials}` and " + f"{num_pending_trials} trials are currently pending; " + "not initiating any additional trials." 
) return [], [] n = max_pending_upper_bound if n == -1 else min(max_pending_upper_bound, n) diff --git a/ax/orchestration/orchestrator_options.py b/ax/orchestration/orchestrator_options.py index 70ab3b04cf3..184db46fedc 100644 --- a/ax/orchestration/orchestrator_options.py +++ b/ax/orchestration/orchestrator_options.py @@ -90,7 +90,7 @@ class OrchestratorOptions: deployment. The size of the groups will be determined as the minimum of ``self.poll_available_capacity()`` and the number of generator runs that the generation strategy is able to produce - without more data or reaching its allowed max paralellism limit. + without more data or reaching its allowed max concurrency limit. debug_log_run_metadata: Whether to log run_metadata for debugging purposes. early_stopping_strategy: A ``BaseEarlyStoppingStrategy`` that determines whether a trial should be stopped given the current state of diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index cb49ed1e39c..df9ef38d9bb 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -836,39 +836,39 @@ def get_trials_data_frame(self) -> pd.DataFrame: """ return self.experiment.to_df() - def get_max_parallelism(self) -> list[tuple[int, int]]: - """Retrieves maximum number of trials that can be scheduled in parallel + def get_max_concurrency(self) -> list[tuple[int, int]]: + """Retrieves maximum number of trials that can be scheduled concurrently at different stages of optimization. Some optimization algorithms profit significantly from sequential optimization (i.e. suggest a few points, get updated with data for them, repeat, see https://ax.dev/docs/bayesopt.html). - Parallelism setting indicates how many trials should be running simulteneously + Concurrency setting indicates how many trials should be running simultaneously (generated, but not yet completed with data). 
The output of this method is mapping of form - {num_trials -> max_parallelism_setting}, where the max_parallelism_setting - is used for num_trials trials. If max_parallelism_setting is -1, as - many of the trials can be ran in parallel, as necessary. If num_trials - in a tuple is -1, then the corresponding max_parallelism_setting + {num_trials -> max_concurrency_setting}, where the max_concurrency_setting + is used for num_trials trials. If max_concurrency_setting is -1, as + many of the trials can be run concurrently, as necessary. If num_trials + in a tuple is -1, then the corresponding max_concurrency_setting should be used for all subsequent trials. For example, if the returned list is [(5, -1), (12, 6), (-1, 3)], - the schedule could be: run 5 trials with any parallelism, run 6 trials in - parallel twice, run 3 trials in parallel for as long as needed. Here, + the schedule could be: run 5 trials with any concurrency, run 6 trials + concurrently twice, run 3 trials concurrently for as long as needed. Here, 'running' a trial means obtaining a next trial from `AxClient` through get_next_trials and completing it with data when available. Returns: - Mapping of form {num_trials -> max_parallelism_setting}. + Mapping of form {num_trials -> max_concurrency_setting}. """ - parallelism_settings = [] + concurrency_settings = [] for node in self.generation_strategy._nodes: - # Extract max_parallelism from MaxGenerationParallelism criterion - max_parallelism = None + # Extract max_concurrency from MaxGenerationParallelism criterion + max_concurrency = None for tc in node.transition_criteria: if isinstance(tc, MaxGenerationParallelism): - max_parallelism = tc.threshold + max_concurrency = tc.threshold break # Try to get num_trials from the node. If there's no MinTrials # criterion (unlimited trials), num_trials will raise UserInputError.
@@ -877,13 +877,23 @@ def get_max_parallelism(self) -> list[tuple[int, int]]: num_trials = node.num_trials except UserInputError: num_trials = -1 - parallelism_settings.append( + concurrency_settings.append( ( num_trials, - max_parallelism if max_parallelism is not None else num_trials, + max_concurrency if max_concurrency is not None else num_trials, ) ) - return parallelism_settings + return concurrency_settings + + def get_max_parallelism(self) -> list[tuple[int, int]]: + """Deprecated. Use `get_max_concurrency` instead.""" + warnings.warn( + "`get_max_parallelism` is deprecated and will be removed in Ax 1.4. " + "Use `get_max_concurrency` instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.get_max_concurrency() def get_optimization_trace( self, objective_optimum: float | None = None diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index afe81f05d01..105bc15f0ea 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -50,7 +50,7 @@ UserInputError, ) from ax.exceptions.generation_strategy import MaxParallelismReachedException -from ax.generation_strategy.dispatch_utils import DEFAULT_BAYESIAN_PARALLELISM +from ax.generation_strategy.dispatch_utils import DEFAULT_BAYESIAN_CONCURRENCY from ax.generation_strategy.generation_strategy import ( GenerationNode, GenerationStep, @@ -511,7 +511,7 @@ def test_default_generation_strategy_continuous(self) -> None: if i < 5: self.assertEqual(gen_limit, 5 - i) else: - self.assertEqual(gen_limit, DEFAULT_BAYESIAN_PARALLELISM) + self.assertEqual(gen_limit, DEFAULT_BAYESIAN_CONCURRENCY) parameterization, trial_index = ax_client.get_next_trial() x, y = parameterization.get("x"), parameterization.get("y") ax_client.complete_trial( @@ -1616,14 +1616,14 @@ def test_keep_generating_without_data(self) -> None: self.assertTrue(len(node0_min_trials) > 0) self.assertFalse(node0_min_trials[0].block_gen_if_met) - # Check that max_parallelism is None by 
verifying no MaxGenerationParallelism + # Check that max_concurrency is None by verifying no MaxGenerationParallelism # criterion exists on node 1 - node1_max_parallelism = [ + node1_max_concurrency = [ tc for tc in ax_client.generation_strategy._nodes[1].transition_criteria if isinstance(tc, MaxGenerationParallelism) ] - self.assertEqual(len(node1_max_parallelism), 0) + self.assertEqual(len(node1_max_concurrency), 0) for _ in range(10): ax_client.get_next_trial() @@ -1939,17 +1939,17 @@ def test_relative_oc_without_sq(self) -> None: def test_recommended_parallelism(self) -> None: ax_client = AxClient() with self.assertRaisesRegex(AssertionError, "No generation strategy"): - ax_client.get_max_parallelism() + ax_client.get_max_concurrency() ax_client.create_experiment( parameters=[ {"name": "x", "type": "range", "bounds": [-5.0, 10.0]}, {"name": "y", "type": "range", "bounds": [0.0, 15.0]}, ], ) - self.assertEqual(ax_client.get_max_parallelism(), [(5, 5), (-1, 3)]) + self.assertEqual(ax_client.get_max_concurrency(), [(5, 5), (-1, 3)]) self.assertEqual( run_trials_using_recommended_parallelism( - ax_client, ax_client.get_max_parallelism(), 20 + ax_client, ax_client.get_max_concurrency(), 20 ), 0, ) @@ -2872,7 +2872,7 @@ def test_estimate_early_stopping_savings(self) -> None: self.assertEqual(ax_client.estimate_early_stopping_savings(), 0) - def test_max_parallelism_exception_when_early_stopping(self) -> None: + def test_max_concurrency_exception_when_early_stopping(self) -> None: ax_client = AxClient() ax_client.create_experiment( parameters=[ diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index 87b6b44c045..d6b5e248123 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -16,6 +16,7 @@ from ax.core.auxiliary import AuxiliaryExperiment from ax.core.batch_trial import BatchTrial from ax.core.data import Data +from ax.core.experiment_design import EXPERIMENT_DESIGN_KEY from ax.core.generator_run 
import GeneratorRun from ax.core.metric import Metric from ax.core.multi_type_experiment import MultiTypeExperiment @@ -76,6 +77,13 @@ def experiment_to_dict(experiment: Experiment) -> dict[str, Any]: """Convert Ax experiment to a dictionary.""" + # Serialize ExperimentDesign into properties + properties = { + **experiment._properties, + EXPERIMENT_DESIGN_KEY: { + "concurrency_limit": experiment.design.concurrency_limit, + }, + } return { "__type": experiment.__class__.__name__, "name": experiment._name, @@ -90,7 +98,7 @@ def experiment_to_dict(experiment: Experiment) -> dict[str, Any]: "trials": experiment.trials, "is_test": experiment.is_test, "data_by_trial": data_to_data_by_trial(data=experiment.data), - "properties": experiment._properties, + "properties": properties, "_trial_type_to_runner": experiment._trial_type_to_runner, } diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index 18156eb7beb..fc32660efd2 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -1660,6 +1660,60 @@ def test_experiment_with_pruning_target_json_roundtrip(self) -> None: ).pruning_target_parameterization, ) + def test_experiment_design_json_roundtrip(self) -> None: + """Test that ExperimentDesign is preserved through JSON serialization.""" + # Setup: create experiment and set concurrency_limit + experiment = get_branin_experiment() + experiment.design.concurrency_limit = 42 + + # Execute: save and load experiment through JSON + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as f: + save_experiment( + experiment, + f.name, + encoder_registry=CORE_ENCODER_REGISTRY, + class_encoder_registry=CORE_CLASS_ENCODER_REGISTRY, + ) + loaded_experiment = load_experiment( + f.name, + decoder_registry=CORE_DECODER_REGISTRY, + class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, + ) + + # Cleanup + os.remove(f.name) + + # Assert: confirm ExperimentDesign 
is preserved + self.assertEqual(experiment, loaded_experiment) + self.assertEqual(loaded_experiment.design.concurrency_limit, 42) + + def test_experiment_design_none_concurrency_json_roundtrip(self) -> None: + """Test that ExperimentDesign with None concurrency_limit is preserved.""" + # Setup: create experiment with default (None) concurrency_limit + experiment = get_branin_experiment() + self.assertIsNone(experiment.design.concurrency_limit) + + # Execute: save and load experiment through JSON + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as f: + save_experiment( + experiment, + f.name, + encoder_registry=CORE_ENCODER_REGISTRY, + class_encoder_registry=CORE_CLASS_ENCODER_REGISTRY, + ) + loaded_experiment = load_experiment( + f.name, + decoder_registry=CORE_DECODER_REGISTRY, + class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, + ) + + # Cleanup + os.remove(f.name) + + # Assert: confirm ExperimentDesign is preserved with None + self.assertEqual(experiment, loaded_experiment) + self.assertIsNone(loaded_experiment.design.concurrency_limit) + def test_multi_objective_from_json_warning(self) -> None: objectives = [get_objective()] diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index 87bf58cb88f..4cd25098d9d 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -27,6 +27,7 @@ from ax.core.data import Data from ax.core.evaluations_to_data import DataType from ax.core.experiment import Experiment +from ax.core.experiment_design import EXPERIMENT_DESIGN_KEY from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric from ax.core.multi_type_experiment import MultiTypeExperiment @@ -228,6 +229,12 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment: elif experiment.runner: runners.append(self.runner_to_sqa(none_throws(experiment.runner))) properties = experiment._properties.copy() + + properties[EXPERIMENT_DESIGN_KEY] = { + attribute: 
getattr(experiment.design, attribute) + for attribute in list(experiment.design.__dict__.keys()) + } + if ( oc := experiment.optimization_config ) is not None and oc.pruning_target_parameterization is not None: diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 30a915c3a07..72c2b7aa96c 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -424,6 +424,33 @@ def test_arm_parameter_values_cast_to_parameter_type(self) -> None: self.assertEqual(none_throws(status_quo).parameters["y"], 0.5) self.assertIs(type(none_throws(status_quo).parameters["y"]), float) + def test_experiment_design_sqa_roundtrip(self) -> None: + """Test that ExperimentDesign is preserved through SQA serialization.""" + # Create experiment and set concurrency_limit + experiment = get_experiment_with_batch_trial() + experiment.design.concurrency_limit = 42 + + # Save and load experiment through SQA + save_experiment(experiment) + loaded_experiment = load_experiment(experiment.name) + + # Verify ExperimentDesign, w/ concurrency value 42 is preserved + self.assertEqual(loaded_experiment, experiment) + self.assertEqual(loaded_experiment.design.concurrency_limit, 42) + + # Repeat process for none value for concurrency_limit + experiment = get_experiment_with_batch_trial() + experiment.name = "experiment_design_none_concurrency_test" + self.assertIsNone(experiment.design.concurrency_limit) + + # Save and load experiment through SQA + save_experiment(experiment) + loaded_experiment = load_experiment(experiment.name) + + # Verify ExperimentDesign with None is preserved + self.assertEqual(loaded_experiment, experiment) + self.assertIsNone(loaded_experiment.design.concurrency_limit) + def test_saving_and_loading_experiment_with_aux_exp(self) -> None: aux_experiment = Experiment( name="test_aux_exp_in_SQAStoreTest", diff --git a/ax/utils/common/complexity_utils.py 
b/ax/utils/common/complexity_utils.py index a0edc391df2..9a3028ca76e 100644 --- a/ax/utils/common/complexity_utils.py +++ b/ax/utils/common/complexity_utils.py @@ -111,7 +111,7 @@ class OptimizationSummary: is True). tolerated_trial_failure_rate: Maximum tolerated trial failure rate (should be <= 0.9). - max_pending_trials: Maximum number of pending trials. + max_pending_trials: Maximum number of concurrent trials. min_failed_trials_for_failure_rate_check: Minimum failed trials before failure rate is checked. non_default_advanced_options: Whether non-default advanced options are set.