Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
# pyre-ignore-all-errors[6, 8, 9]

import re
import warnings
Expand Down Expand Up @@ -281,13 +282,16 @@ def _init_experiment_from_sqa(
)

return Experiment(
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
name=experiment_sqa.name,
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
description=experiment_sqa.description,
search_space=search_space,
optimization_config=opt_config,
tracking_metrics=all_metrics,
runner=runner,
status_quo=status_quo,
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
is_test=experiment_sqa.is_test,
properties=properties,
auxiliary_experiments_by_purpose=auxiliary_experiments_by_purpose,
Expand Down Expand Up @@ -333,11 +337,13 @@ def _init_mt_experiment_from_sqa(
)

default_trial_type = none_throws(experiment_sqa.default_trial_type)
# pyre-ignore[9]: SA 2.0 Column[Optional[str]] keys; runtime str.
trial_type_to_runner: dict[str, Runner | None] = {
none_throws(sqa_runner.trial_type): self.runner_from_sqa(sqa_runner)
for sqa_runner in experiment_sqa.runners
}
if len(trial_type_to_runner) == 0:
# pyre-ignore[9]: SA 2.0 Column[Optional[str]] keys; runtime str.
trial_type_to_runner = {default_trial_type: None}
trial_types_with_metrics = {
metric.trial_type
Expand All @@ -347,13 +353,18 @@ def _init_mt_experiment_from_sqa(
# trial_type_to_runner is instantiated to map all trial types to None,
# so the trial types are associated with the experiment. This is
# important for adding metrics.
# pyre-ignore[6]: SA 2.0 Column[T] keys vs str keys.
trial_type_to_runner.update(dict.fromkeys(trial_types_with_metrics))

experiment = MultiTypeExperiment(
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
name=experiment_sqa.name,
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
description=experiment_sqa.description,
search_space=search_space,
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
default_trial_type=default_trial_type,
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
default_runner=trial_type_to_runner.get(default_trial_type),
optimization_config=opt_config,
status_quo=status_quo,
Expand Down
1 change: 1 addition & 0 deletions ax/storage/sqa_store/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
# pyre-ignore-all-errors[6]

from logging import Logger
from typing import cast
Expand Down
3 changes: 3 additions & 0 deletions ax/storage/sqa_store/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ class JSONEncodedLongText(JSONEncodedObject):

# pyre-ignore[9]: SA 2.0 typed as_mutable returns TypeEngine; runtime TypeDecorator.
JSONEncodedList: TypeDecorator = MutableList.as_mutable(JSONEncodedObject)
# pyre-ignore[9]: SA 2.0 typed as_mutable returns TypeEngine; runtime TypeDecorator.
JSONEncodedDict: TypeDecorator = MutableDict.as_mutable(JSONEncodedObject)
# pyre-ignore[9]: SA 2.0 typed as_mutable returns TypeEngine; runtime TypeDecorator.
JSONEncodedTextDict: TypeDecorator = MutableDict.as_mutable(JSONEncodedText)
# pyre-ignore[9]: SA 2.0 typed as_mutable returns TypeEngine; runtime TypeDecorator.
JSONEncodedLongTextDict: TypeDecorator = MutableDict.as_mutable(JSONEncodedLongText)
1 change: 1 addition & 0 deletions ax/storage/sqa_store/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
# pyre-ignore-all-errors[6, 8]

import logging
from collections.abc import Mapping
Expand Down
1 change: 1 addition & 0 deletions ax/storage/sqa_store/reduced_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
# pyre-ignore-all-errors[24]


from ax.storage.sqa_store.sqa_classes import SQAGeneratorRun, SQATrial
Expand Down
1 change: 1 addition & 0 deletions ax/storage/sqa_store/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
# pyre-ignore-all-errors[6, 8, 9]

import os
from collections.abc import Callable, Generator, Sequence
Expand Down
1 change: 1 addition & 0 deletions ax/storage/sqa_store/sqa_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
# pyre-ignore-all-errors[8]

from __future__ import annotations

Expand Down
15 changes: 15 additions & 0 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1841,10 +1841,12 @@ def test_parameter_validation(self) -> None:
with session_scope() as session:
session.add(sqa_parameter)

# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
sqa_parameter.experiment_id = 0
with session_scope() as session:
session.add(sqa_parameter)
with self.assertRaises(ValueError):
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
sqa_parameter.generator_run_id = 0
with session_scope() as session:
session.add(sqa_parameter)
Expand All @@ -1858,6 +1860,7 @@ def test_parameter_validation(self) -> None:
with session_scope() as session:
session.add(sqa_parameter)
with self.assertRaises(ValueError):
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
sqa_parameter.experiment_id = 0
with session_scope() as session:
session.add(sqa_parameter)
Expand Down Expand Up @@ -1907,10 +1910,12 @@ def test_parameter_constraint_validation(self) -> None:
with session_scope() as session:
session.add(sqa_parameter_constraint)

# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
sqa_parameter_constraint.experiment_id = 0
with session_scope() as session:
session.add(sqa_parameter_constraint)
with self.assertRaises(ValueError):
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
sqa_parameter_constraint.generator_run_id = 0
with session_scope() as session:
session.add(sqa_parameter_constraint)
Expand All @@ -1924,6 +1929,7 @@ def test_parameter_constraint_validation(self) -> None:
with session_scope() as session:
session.add(sqa_parameter_constraint)
with self.assertRaises(ValueError):
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
sqa_parameter_constraint.experiment_id = 0
with session_scope() as session:
session.add(sqa_parameter_constraint)
Expand Down Expand Up @@ -1961,10 +1967,12 @@ def test_metric_validation(self) -> None:
with session_scope() as session:
session.add(sqa_metric)

# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
sqa_metric.experiment_id = 0
with session_scope() as session:
session.add(sqa_metric)
with self.assertRaises(ValueError):
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
sqa_metric.generator_run_id = 0
with session_scope() as session:
session.add(sqa_metric)
Expand All @@ -1979,6 +1987,7 @@ def test_metric_validation(self) -> None:
with session_scope() as session:
session.add(sqa_metric)
with self.assertRaises(ValueError):
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
sqa_metric.experiment_id = 0
with session_scope() as session:
session.add(sqa_metric)
Expand Down Expand Up @@ -2025,13 +2034,16 @@ def test_metric_decode_failure(self) -> None:
with self.assertRaises(SQADecodeError):
self.decoder.metric_from_sqa(sqa_metric)

# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
sqa_metric.metric_type = CORE_METRIC_REGISTRY[BraninMetric]
# pyre-fixme[8]: Attribute has type `MetricIntent`; used as `str`.
sqa_metric.intent = "foobar"
with self.assertRaises(SQADecodeError):
self.decoder.metric_from_sqa(sqa_metric)

# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
sqa_metric.intent = MetricIntent.TRACKING
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
sqa_metric.properties = {}
with self.assertRaises(ValueError):
self.decoder.metric_from_sqa(sqa_metric)
Expand Down Expand Up @@ -2080,10 +2092,12 @@ def test_runner_validation(self) -> None:
with session_scope() as session:
session.add(sqa_runner)

# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
sqa_runner.experiment_id = 0
with session_scope() as session:
session.add(sqa_runner)
with self.assertRaises(ValueError):
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
sqa_runner.trial_id = 0
with session_scope() as session:
session.add(sqa_runner)
Expand All @@ -2094,6 +2108,7 @@ def test_runner_validation(self) -> None:
with session_scope() as session:
session.add(sqa_runner)
with self.assertRaises(ValueError):
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
sqa_runner.experiment_id = 0
with session_scope() as session:
session.add(sqa_runner)
Expand Down
1 change: 1 addition & 0 deletions ax/storage/sqa_store/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@


def listens_for_multiple(
# pyre-ignore[24]: SA 2.0 requires a type param on InstrumentedAttribute.
targets: list[InstrumentedAttribute],
identifier: str,
*args: Any,
Expand Down
Loading