diff --git a/plugboard-schemas/plugboard_schemas/tune.py b/plugboard-schemas/plugboard_schemas/tune.py index 7018b6f9..76580c86 100644 --- a/plugboard-schemas/plugboard_schemas/tune.py +++ b/plugboard-schemas/plugboard_schemas/tune.py @@ -23,12 +23,16 @@ class OptunaSpec(PlugboardBaseModel): the `ProcessSpec` will be passed to it. study_name: Optional; The name of the study. storage: Optional; The storage URI to save the optimisation results to. + points_to_evaluate: Optional; A list of initial parameter configurations to evaluate + first. Each entry is a dict mapping parameter full names to values. Useful for + providing a warm start when exploring large or heavily constrained search spaces. """ type: _t.Literal["ray.tune.search.optuna.OptunaSearch"] = "ray.tune.search.optuna.OptunaSearch" space: str | None = None study_name: str | None = None storage: str | None = None + points_to_evaluate: list[dict[str, _t.Any]] | None = None class BaseFieldSpec(PlugboardBaseModel, ABC): diff --git a/plugboard/exceptions/__init__.py b/plugboard/exceptions/__init__.py index 1748b603..499a7a05 100644 --- a/plugboard/exceptions/__init__.py +++ b/plugboard/exceptions/__init__.py @@ -98,9 +98,18 @@ class ValidationError(Exception): class ConstraintError(Exception): - """Raised when a constraint is violated.""" - - pass + """Raised when a constraint is violated. + + Args: + *args: Standard exception arguments. + objective_value: Optional; A custom value to assign to the objective when this constraint + is violated. If not provided, the tuner will assign plus or minus infinity depending + on the optimisation direction. + """ + + def __init__(self, *args: object, objective_value: float | None = None) -> None: + super().__init__(*args) + self.objective_value = objective_value class ProcessStatusError(Exception): diff --git a/plugboard/tune/tune.py b/plugboard/tune/tune.py index 827643f5..609a8180 100644 --- a/plugboard/tune/tune.py +++ b/plugboard/tune/tune.py @@ -288,14 +288,21 @@ def fn(config: dict[str, _t.Any]) -> dict[str, _t.Any]: # pragma: no cover result = { obj.full_name: self._get_objective(process, obj) for obj in self._objective } - except* ConstraintError as e: + except* ConstraintError as eg: modes = self._mode if isinstance(self._mode, list) else [self._mode] self._logger.warning( "Constraint violated during optimisation, stopping early", - constraint_error=str(e), + constraint_error=str(eg), + ) + first_exc = ( + _t.cast(ConstraintError, eg.exceptions[0]) if eg.exceptions else None ) result = { - obj.full_name: math.inf if mode == "min" else -math.inf + obj.full_name: ( + first_exc.objective_value + if first_exc is not None and first_exc.objective_value is not None + else (math.inf if mode == "min" else -math.inf) + ) for obj, mode in zip(self._objective, modes) } diff --git a/tests/integration/test_tuner.py b/tests/integration/test_tuner.py index 1ac2199d..bba81b82 100644 --- a/tests/integration/test_tuner.py +++ b/tests/integration/test_tuner.py @@ -34,6 +34,16 @@ async def step(self) -> None: await super().step() +class ConstrainedBWithObjectiveValue(B): + """Component with a constraint that provides a custom objective value.""" + + async def step(self) -> None: + """Override step to apply a constraint with a custom objective value.""" + if self.in_1 > 10: + raise ConstraintError("Input must not be greater than 10", objective_value=0.0) + await super().step() + + class DynamicListComponent(ComponentTestHelper): """Component with a dynamic list parameter for tuning.""" @@ -294,6 +304,55 @@ async def test_tune_with_constraint(config: dict, ray_ctx: None) -> None: ) +@pytest.mark.tuner +@pytest.mark.asyncio +async def test_tune_with_constraint_objective_value(config: dict, ray_ctx: None) -> None: + """Tests that a ConstraintError with objective_value uses that value instead of infinity.""" + spec = ConfigSpec.model_validate(config) + process_spec = spec.plugboard.process + # Replace component B with a version that provides a custom objective value on constraint + process_spec.args.components[ + 1 + ].type = "tests.integration.test_tuner.ConstrainedBWithObjectiveValue" + tuner = Tuner( + objective=ObjectiveSpec( + object_type="component", + object_name="c", + field_type="field", + field_name="in_1", + ), + parameters=[ + IntParameterSpec( + object_type="component", + object_name="a", + field_type="arg", + field_name="iters", + lower=5, + upper=15, + ) + ], + num_samples=12, + mode="max", + max_concurrent=2, + algorithm=OptunaSpec(), + ) + best_result = tuner.run( + spec=process_spec, + ) + result = tuner.result_grid + # There must be no failed trials + assert not any(t.error for t in result) + # Optimum must be at or below 10 (constraint threshold) + assert best_result.metrics["component.c.field.in_1"] <= 10 + # When constraint is violated the custom objective_value (0.0) must be used, not -inf + # The constraint raises when in_1 > 10; in_1 = iters - 1, so iters > 11 violates it + assert all( + t.metrics["component.c.field.in_1"] == 0.0 + for t in result + if t.config["component.a.arg.iters"] > 11 + ) + + @pytest.mark.tuner @pytest.mark.asyncio @pytest.mark.parametrize("space_func", [custom_space, custom_space_with_process_spec]) diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index 4cb0e1a0..a8271795 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -8,6 +8,7 @@ import pytest from ray.tune.search.optuna import OptunaSearch +from plugboard.exceptions import ConstraintError from plugboard.schemas import ( CategoricalParameterSpec, ConfigSpec, @@ -131,3 +132,58 @@ def test_optuna_storage_uri_conversion(temp_dir: str) -> None: tuner.run(spec=MagicMock()) passed_alg = mock_tuner_cls.call_args.kwargs["tune_config"].search_alg assert isinstance(passed_alg, OptunaSearch) + + +def test_optuna_points_to_evaluate(config: dict) -> None: + """Test that points_to_evaluate is passed through to the OptunaSearch algorithm.""" + spec = ConfigSpec.model_validate(config) + process_spec = spec.plugboard.process + points = [{"component.a.arg.x": 7, "component.a.arg.y": 0.3}] + tuner = Tuner( + objective=ObjectiveSpec( + object_type="component", + object_name="c", + field_type="field", + field_name="in_1", + ), + parameters=[ + IntParameterSpec( + object_type="component", + object_name="a", + field_type="arg", + field_name="x", + lower=6, + upper=8, + ), + FloatParameterSpec( + object_type="component", + object_name="a", + field_type="arg", + field_name="y", + lower=0.1, + upper=0.5, + ), + ], + num_samples=3, + mode="max", + algorithm=OptunaSpec(points_to_evaluate=points), + ) + with patch("ray.tune.Tuner") as mock_tuner_cls: + tuner.run(spec=process_spec) + search_alg = mock_tuner_cls.call_args.kwargs["tune_config"].search_alg + # The underlying OptunaSearch must have received the points_to_evaluate + assert isinstance(search_alg, OptunaSearch) + assert search_alg._points_to_evaluate == points + + +def test_constraint_error_objective_value() -> None: + """Test that ConstraintError stores an optional objective_value.""" + # Default (no objective_value): backward compatible usage + err = ConstraintError("constraint violated") + assert err.objective_value is None + assert str(err) == "constraint violated" + + # With objective_value keyword argument + err_with_value = ConstraintError("constraint violated", objective_value=5.0) + assert err_with_value.objective_value == 5.0 + assert str(err_with_value) == "constraint violated"