Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions plugboard-schemas/plugboard_schemas/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@ class OptunaSpec(PlugboardBaseModel):
the `ProcessSpec` will be passed to it.
study_name: Optional; The name of the study.
storage: Optional; The storage URI to save the optimisation results to.
points_to_evaluate: Optional; A list of initial parameter configurations to evaluate
first. Each entry is a dict mapping parameter full names to values. Useful for
providing a warm start when exploring large or heavily constrained search spaces.
"""

type: _t.Literal["ray.tune.search.optuna.OptunaSearch"] = "ray.tune.search.optuna.OptunaSearch"
space: str | None = None
study_name: str | None = None
storage: str | None = None
points_to_evaluate: list[dict[str, _t.Any]] | None = None


class BaseFieldSpec(PlugboardBaseModel, ABC):
Expand Down
15 changes: 12 additions & 3 deletions plugboard/exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,18 @@ class ValidationError(Exception):


class ConstraintError(Exception):
"""Raised when a constraint is violated."""

pass
"""Raised when a constraint is violated.

Args:
*args: Standard exception arguments.
objective_value: Optional; A custom value to assign to the objective when this constraint
is violated. If not provided, the tuner will assign plus or minus infinity depending
on the optimisation direction.
"""

def __init__(self, *args: object, objective_value: float | None = None) -> None:
super().__init__(*args)
self.objective_value = objective_value


class ProcessStatusError(Exception):
Expand Down
13 changes: 10 additions & 3 deletions plugboard/tune/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,14 +288,21 @@ def fn(config: dict[str, _t.Any]) -> dict[str, _t.Any]: # pragma: no cover
result = {
obj.full_name: self._get_objective(process, obj) for obj in self._objective
}
except* ConstraintError as e:
except* ConstraintError as eg:
modes = self._mode if isinstance(self._mode, list) else [self._mode]
self._logger.warning(
"Constraint violated during optimisation, stopping early",
constraint_error=str(e),
constraint_error=str(eg),
)
first_exc = (
_t.cast(ConstraintError, eg.exceptions[0]) if eg.exceptions else None
)
result = {
obj.full_name: math.inf if mode == "min" else -math.inf
obj.full_name: (
first_exc.objective_value
if first_exc is not None and first_exc.objective_value is not None
else (math.inf if mode == "min" else -math.inf)
)
for obj, mode in zip(self._objective, modes)
}

Expand Down
59 changes: 59 additions & 0 deletions tests/integration/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ async def step(self) -> None:
await super().step()


class ConstrainedBWithObjectiveValue(B):
"""Component with a constraint that provides a custom objective value."""

async def step(self) -> None:
"""Override step to apply a constraint with a custom objective value."""
if self.in_1 > 10:
raise ConstraintError("Input must not be greater than 10", objective_value=0.0)
await super().step()


class DynamicListComponent(ComponentTestHelper):
"""Component with a dynamic list parameter for tuning."""

Expand Down Expand Up @@ -294,6 +304,55 @@ async def test_tune_with_constraint(config: dict, ray_ctx: None) -> None:
)


@pytest.mark.tuner
@pytest.mark.asyncio
async def test_tune_with_constraint_objective_value(config: dict, ray_ctx: None) -> None:
"""Tests that a ConstraintError with objective_value uses that value instead of infinity."""
spec = ConfigSpec.model_validate(config)
process_spec = spec.plugboard.process
# Replace component B with a version that provides a custom objective value on constraint
process_spec.args.components[
1
].type = "tests.integration.test_tuner.ConstrainedBWithObjectiveValue"
tuner = Tuner(
objective=ObjectiveSpec(
object_type="component",
object_name="c",
field_type="field",
field_name="in_1",
),
parameters=[
IntParameterSpec(
object_type="component",
object_name="a",
field_type="arg",
field_name="iters",
lower=5,
upper=15,
)
],
num_samples=12,
mode="max",
max_concurrent=2,
algorithm=OptunaSpec(),
)
best_result = tuner.run(
spec=process_spec,
)
result = tuner.result_grid
# There must be no failed trials
assert not any(t.error for t in result)
# Optimum must be at or below 10 (constraint threshold)
assert best_result.metrics["component.c.field.in_1"] <= 10
# When constraint is violated the custom objective_value (0.0) must be used, not -inf
# The constraint raises when in_1 > 10; in_1 = iters - 1, so iters > 11 violates it
assert all(
t.metrics["component.c.field.in_1"] == 0.0
for t in result
if t.config["component.a.arg.iters"] > 11
)


@pytest.mark.tuner
@pytest.mark.asyncio
@pytest.mark.parametrize("space_func", [custom_space, custom_space_with_process_spec])
Expand Down
56 changes: 56 additions & 0 deletions tests/unit/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
from ray.tune.search.optuna import OptunaSearch

from plugboard.exceptions import ConstraintError
from plugboard.schemas import (
CategoricalParameterSpec,
ConfigSpec,
Expand Down Expand Up @@ -131,3 +132,58 @@ def test_optuna_storage_uri_conversion(temp_dir: str) -> None:
tuner.run(spec=MagicMock())
passed_alg = mock_tuner_cls.call_args.kwargs["tune_config"].search_alg
assert isinstance(passed_alg, OptunaSearch)


def test_optuna_points_to_evaluate(config: dict) -> None:
"""Test that points_to_evaluate is passed through to the OptunaSearch algorithm."""
spec = ConfigSpec.model_validate(config)
process_spec = spec.plugboard.process
points = [{"component.a.arg.x": 7, "component.a.arg.y": 0.3}]
tuner = Tuner(
objective=ObjectiveSpec(
object_type="component",
object_name="c",
field_type="field",
field_name="in_1",
),
parameters=[
IntParameterSpec(
object_type="component",
object_name="a",
field_type="arg",
field_name="x",
lower=6,
upper=8,
),
FloatParameterSpec(
object_type="component",
object_name="a",
field_type="arg",
field_name="y",
lower=0.1,
upper=0.5,
),
],
num_samples=3,
mode="max",
algorithm=OptunaSpec(points_to_evaluate=points),
)
with patch("ray.tune.Tuner") as mock_tuner_cls:
tuner.run(spec=process_spec)
search_alg = mock_tuner_cls.call_args.kwargs["tune_config"].search_alg
# The underlying OptunaSearch must have received the points_to_evaluate
assert isinstance(search_alg, OptunaSearch)
assert search_alg._points_to_evaluate == points


def test_constraint_error_objective_value() -> None:
"""Test that ConstraintError stores an optional objective_value."""
# Default (no objective_value): backward compatible usage
err = ConstraintError("constraint violated")
assert err.objective_value is None
assert str(err) == "constraint violated"

# With objective_value keyword argument
err_with_value = ConstraintError("constraint violated", objective_value=5.0)
assert err_with_value.objective_value == 5.0
assert str(err_with_value) == "constraint violated"
Loading