Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ax/analysis/healthcheck/early_stopping_healthcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ax.early_stopping.dispatch import get_default_ess_or_none
from ax.early_stopping.experiment_replay import (
estimate_hypothetical_early_stopping_savings,
MAX_PENDING_TRIALS,
MAX_CONCURRENT_TRIALS,
MIN_SAVINGS_THRESHOLD,
)
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
Expand Down Expand Up @@ -81,7 +81,7 @@ def __init__(
self,
early_stopping_strategy: BaseEarlyStoppingStrategy | None = None,
min_savings_threshold: float = MIN_SAVINGS_THRESHOLD,
max_pending_trials: int = MAX_PENDING_TRIALS,
max_pending_trials: int = MAX_CONCURRENT_TRIALS,
auto_early_stopping_config: AutoEarlyStoppingConfig | None = None,
nudge_additional_info: str | None = None,
) -> None:
Expand All @@ -95,7 +95,7 @@ def __init__(
single-objective unconstrained experiments.
min_savings_threshold: Minimum savings threshold to suggest early
stopping. Default is 0.1 (10% savings).
max_pending_trials: Maximum number of pending trials for replay
max_pending_trials: Maximum number of concurrent trials for replay
orchestrator. Default is 5.
auto_early_stopping_config: A string for configuring automated early
stopping strategy.
Expand Down
22 changes: 19 additions & 3 deletions ax/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# pyre-strict

import json
import warnings
from collections.abc import Iterable, Sequence
from logging import Logger
from typing import Any, Literal, Self
Expand Down Expand Up @@ -43,7 +44,7 @@
BaseEarlyStoppingStrategy,
PercentileEarlyStoppingStrategy,
)
from ax.exceptions.core import ObjectNotFoundError, UnsupportedError
from ax.exceptions.core import ObjectNotFoundError, UnsupportedError, UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.orchestration.orchestrator import Orchestrator, OrchestratorOptions
from ax.service.utils.best_point_mixin import BestPointMixin
Expand Down Expand Up @@ -710,9 +711,11 @@ def mark_trial_early_stopped(self, trial_index: int) -> None:
def run_trials(
self,
max_trials: int,
parallelism: int = 1,
concurrency: int = 1,
tolerated_trial_failure_rate: float = 0.5,
initial_seconds_between_polls: int = 1,
# Deprecated argument for backwards compatibility.
parallelism: int | None = None,
) -> None:
"""
Run maximum_trials trials in a loop by creating an ephemeral Orchestrator under
Expand All @@ -721,12 +724,25 @@ def run_trials(

Saves to database on completion if ``storage_config`` is present.
"""
# Handle deprecated `parallelism` argument.
if parallelism is not None:
warnings.warn(
"`parallelism` is deprecated and will be removed in Ax 1.4. "
"Use `concurrency` instead.",
DeprecationWarning,
stacklevel=2,
)
if concurrency != 1:
raise UserInputError(
"Cannot specify both `parallelism` and `concurrency`."
)
concurrency = parallelism

orchestrator = Orchestrator(
experiment=self._experiment,
generation_strategy=self._generation_strategy_or_choose(),
options=OrchestratorOptions(
max_pending_trials=parallelism,
max_pending_trials=concurrency,
tolerated_trial_failure_rate=tolerated_trial_failure_rate,
init_seconds_between_polls=initial_seconds_between_polls,
),
Expand Down
2 changes: 1 addition & 1 deletion ax/benchmark/benchmark_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
class BenchmarkMethod(Base):
"""Benchmark method, represented in terms of Ax generation strategy (which tells us
which models to use when) and Orchestrator options (which tell us extra execution
information like maximum parallelism, early stopping configuration, etc.).
information like maximum pending trials, early stopping configuration, etc.).

Args:
name: String description. Defaults to the name of the generation strategy.
Expand Down
17 changes: 17 additions & 0 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from ax.core.base_trial import BaseTrial
from ax.core.batch_trial import BatchTrial
from ax.core.data import combine_data_rows_favoring_recent, Data
from ax.core.experiment_design import EXPERIMENT_DESIGN_KEY, ExperimentDesign
from ax.core.experiment_status import ExperimentStatus
from ax.core.generator_run import GeneratorRun
from ax.core.llm_provider import LLMMessage
Expand Down Expand Up @@ -152,6 +153,17 @@ def __init__(
self._status: ExperimentStatus | None = None
self._trials: dict[int, BaseTrial] = {}
self._properties: dict[str, Any] = properties or {}
self._design: ExperimentDesign = ExperimentDesign()
# Properties is temporarily being used to serialize/deserialize experiment
# in fbcode/ax/storage/sqa_store/encoder.py.
# Eventually design will have first-class support as an experiment
# attribute
# TODO[drfreund, mpolson64]: Replace with proper storage as part of the
# refactor.
if (
design_dict := self._properties.pop(EXPERIMENT_DESIGN_KEY, None)
) is not None:
self._design.concurrency_limit = design_dict.get("concurrency_limit")

# Initialize trial type to runner mapping
self._default_trial_type = default_trial_type
Expand Down Expand Up @@ -294,6 +306,11 @@ def experiment_status_from_generator_runs(

return suggested_statuses.pop()

@property
def design(self) -> ExperimentDesign:
"""The experiment design configuration."""
return self._design

@property
def search_space(self) -> SearchSpace:
"""The search space for this experiment.
Expand Down
34 changes: 34 additions & 0 deletions ax/core/experiment_design.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

EXPERIMENT_DESIGN_KEY: str = "experiment_design"


@dataclass
class ExperimentDesign:
"""Struct that holds "experiment design" configuration: these are
experiment-level settings that pertain to "how the experiment will be
run or conducted", but are agnostic to the specific evaluation
backend, to which the trials will be deployed.

NOTE: In the future, we might treat concurrency limit as expressed
in terms of "full arm equivalents" as opposed to just "number of arms",
to cover for the multi-fidelity cases.

NOTE: in ax/storage/sqa_store/encoder.py, attributes of this class
are automatically serialized and stored in experiment.properties

Args:
concurrency_limit: Maximum number of arms to run within one or
multiple trials, in parallel. In experiments that consist of
Trials, this is equivalent to the total number of trials
that should run in parallel. In experiments with BatchTrials,
total number of arms can be spread across one or
multiple BatchTrials.
"""

concurrency_limit: int | None = None
2 changes: 1 addition & 1 deletion ax/core/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def poll_available_capacity(self) -> int:
artificially force this method to limit capacity; ``Orchestrator`` has other
limitations in place to limit number of trials running at once,
like the ``OrchestratorOptions.max_pending_trials`` setting, or
more granular control in the form of the `max_parallelism`
more granular control in the form of the `max_concurrency`
setting in each of the `GenerationStep`s of a `GenerationStrategy`).

Returns:
Expand Down
33 changes: 33 additions & 0 deletions ax/core/tests/test_experiment_design.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, Dict

from ax.core import Experiment
from ax.core.experiment_design import EXPERIMENT_DESIGN_KEY, ExperimentDesign
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_branin_search_space


class ExperimentDesignTest(TestCase):
"""Tests covering ExperimentDesign class and its usage in ax Experiment"""

def test_experiment_design_property(self) -> None:
"""Test that Experiment.design property returns ExperimentDesign instance."""
experiment = Experiment(
name="test",
search_space=get_branin_search_space(),
)
self.assertIsInstance(experiment.design, ExperimentDesign)
self.assertIsNone(experiment.design.concurrency_limit)

properties: Dict[str, Any] = {EXPERIMENT_DESIGN_KEY: {"concurrency_limit": 42}}
experiment = Experiment(
name="test", search_space=get_branin_search_space(), properties=properties
)
self.assertEqual(experiment.design.concurrency_limit, 42)
6 changes: 3 additions & 3 deletions ax/early_stopping/experiment_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
# Constants for experiment replay
MAX_REPLAY_TRIALS: int = 50
REPLAY_NUM_POINTS_PER_CURVE: int = 20
MAX_PENDING_TRIALS: int = 5
MAX_CONCURRENT_TRIALS: int = 5
MIN_SAVINGS_THRESHOLD: float = 0.1 # 10% threshold


Expand Down Expand Up @@ -119,7 +119,7 @@ def replay_experiment(
def estimate_hypothetical_early_stopping_savings(
experiment: Experiment,
metric: Metric,
max_pending_trials: int = MAX_PENDING_TRIALS,
max_pending_trials: int = MAX_CONCURRENT_TRIALS,
) -> float:
"""Estimate hypothetical early stopping savings using experiment replay.

Expand All @@ -130,7 +130,7 @@ def estimate_hypothetical_early_stopping_savings(
Args:
experiment: The experiment to analyze.
metric: The metric to use for early stopping replay.
max_pending_trials: Maximum number of pending trials for the replay
max_pending_trials: Maximum number of concurrent trials for the replay
orchestrator. Defaults to 5.

Returns:
Expand Down
Loading
Loading