Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions ax/storage/sqa_store/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,22 @@
T = TypeVar("T")


def _bound_engine(session_factory: scoped_session) -> Engine:
"""Return the ``Engine`` bound to ``session_factory``, raising if not bound.

SA 2.0 types ``scoped_session.bind`` as ``Connection | Engine | None``. In
this module we always bind a freshly-created ``Engine`` via ``sessionmaker``,
so the runtime value is always an ``Engine``. This narrows the type for both
pyre and runtime safety.
"""
bind = session_factory.bind
if not isinstance(bind, Engine):
raise ValueError(
f"SESSION_FACTORY must be bound to an Engine, got {type(bind).__name__}."
)
return bind


class SQABase:
"""Metaclass for SQLAlchemy classes corresponding to core Ax classes."""

Expand Down Expand Up @@ -164,8 +180,7 @@ def init_engine_and_session_factory(

if SESSION_FACTORY is not None:
if force_init:
# pyre-ignore[16]: SA 2.0 bind is Union; runtime Engine.
SESSION_FACTORY.bind.dispose()
_bound_engine(SESSION_FACTORY).dispose()
else:
return
if url is not None:
Expand Down Expand Up @@ -205,8 +220,7 @@ def init_test_engine_and_session_factory(

if SESSION_FACTORY is not None:
if force_init:
# pyre-ignore[16]: SA 2.0 bind is Union; runtime Engine.
SESSION_FACTORY.bind.dispose()
_bound_engine(SESSION_FACTORY).dispose()
else:
return
engine = create_test_engine(path=tier_or_path, echo=echo)
Expand Down Expand Up @@ -264,8 +278,7 @@ def get_engine() -> Engine:
global SESSION_FACTORY
if SESSION_FACTORY is None:
raise ValueError("Engine must be initialized first.")
# pyre-ignore[7]: SA 2.0 bind is Union; runtime Engine.
return SESSION_FACTORY.bind
return _bound_engine(SESSION_FACTORY)


@contextmanager
Expand Down Expand Up @@ -333,6 +346,5 @@ def session_context(
finally:
# Restore the old session factory
session_factory.close()
# pyre-ignore[16]: SA 2.0 bind is Union; runtime Engine.
session_factory.bind.dispose()
_bound_engine(session_factory).dispose()
SESSION_FACTORY = old_session
9 changes: 3 additions & 6 deletions ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,8 @@ def _auxiliary_experiments_by_purpose_from_experiment_sqa(
):
continue
aux_experiment = auxiliary_experiment_from_name(
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
experiment_name=auxiliary_experiment_sqa.source_experiment.name,
config=self.config,
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
is_active=auxiliary_experiment_sqa.is_active,
reduced_state=reduced_state,
)
Expand Down Expand Up @@ -252,9 +250,7 @@ def _init_experiment_from_sqa(
raise SQADecodeError("Experiment SearchSpace cannot be None.")
status_quo = (
Arm(
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
parameters=experiment_sqa.status_quo_parameters,
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
name=experiment_sqa.status_quo_name,
)
if experiment_sqa.status_quo_parameters is not None
Expand Down Expand Up @@ -323,21 +319,21 @@ def _init_mt_experiment_from_sqa(
raise SQADecodeError("Experiment SearchSpace cannot be None.")
status_quo = (
Arm(
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
parameters=experiment_sqa.status_quo_parameters,
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
name=experiment_sqa.status_quo_name,
)
if experiment_sqa.status_quo_parameters is not None
else None
)

default_trial_type = none_throws(experiment_sqa.default_trial_type)
# pyre-ignore[9]: SA 2.0 Column[Optional[str]] keys; runtime str.
trial_type_to_runner: dict[str, Runner | None] = {
none_throws(sqa_runner.trial_type): self.runner_from_sqa(sqa_runner)
for sqa_runner in experiment_sqa.runners
}
if len(trial_type_to_runner) == 0:
# pyre-ignore[9]: SA 2.0 Column[Optional[str]] keys; runtime str.
trial_type_to_runner = {default_trial_type: None}
trial_types_with_metrics = {
metric.trial_type
Expand All @@ -347,6 +343,7 @@ def _init_mt_experiment_from_sqa(
# trial_type_to_runner is instantiated to map all trial types to None,
# so the trial types are associated with the experiment. This is
# important for adding metrics.
# pyre-ignore[6]: SA 2.0 Column[T] keys vs str keys.
trial_type_to_runner.update(dict.fromkeys(trial_types_with_metrics))

experiment = MultiTypeExperiment(
Expand Down
7 changes: 7 additions & 0 deletions ax/storage/sqa_store/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,15 @@ class JSONEncodedLongText(JSONEncodedObject):
impl = Text(LONGTEXT_BYTES)


# `Mutable*.as_mutable()` returns a `TypeEngine` subclass per SA 2.0 stubs.
# Cannot annotate as `TypeEngine[Any]` because SA 1.4's `TypeEngine` is not a
# Generic class (`type 'TypeEngine' is not subscriptable` at runtime under 1.4).
# Keep `TypeDecorator` and suppress the SA 2.0 type-stub mismatch on each line.
# pyre-ignore[9]: SA 2.0 typed as_mutable returns TypeEngine; runtime TypeDecorator.
JSONEncodedList: TypeDecorator = MutableList.as_mutable(JSONEncodedObject)
# pyre-ignore[9]: SA 2.0 typed as_mutable returns TypeEngine; runtime TypeDecorator.
JSONEncodedDict: TypeDecorator = MutableDict.as_mutable(JSONEncodedObject)
# pyre-ignore[9]: SA 2.0 typed as_mutable returns TypeEngine; runtime TypeDecorator.
JSONEncodedTextDict: TypeDecorator = MutableDict.as_mutable(JSONEncodedText)
# pyre-ignore[9]: SA 2.0 typed as_mutable returns TypeEngine; runtime TypeDecorator.
JSONEncodedLongTextDict: TypeDecorator = MutableDict.as_mutable(JSONEncodedLongText)
4 changes: 2 additions & 2 deletions ax/storage/sqa_store/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,8 +714,8 @@ def load_analysis_cards_by_experiment_name(
analysis_card_sqa_class.children
)

exp_sqa_class: SQAExperiment = cast(
SQAExperiment, decoder.config.class_to_sqa_class[Experiment]
exp_sqa_class = cast(
type[SQAExperiment], decoder.config.class_to_sqa_class[Experiment]
)

with session_scope() as session:
Expand Down
8 changes: 5 additions & 3 deletions ax/storage/sqa_store/reduced_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
# pyre-ignore-all-errors[24]
#
# SA 2.0 requires a type param on InstrumentedAttribute, but SA 1.4
# InstrumentedAttribute is not subscriptable at runtime (`type is not a generic
# class`), so we keep the bare form to preserve dual-version compatibility.


from ax.storage.sqa_store.sqa_classes import SQAGeneratorRun, SQATrial
from sqlalchemy.orm import defaultload, strategy_options
from sqlalchemy.orm.attributes import InstrumentedAttribute


# pyre-fixme[9]: `GR_LARGE_MODEL_ATTRS` is declared as `List[InstrumentedAttribute]`
# but SQLAlchemy class attributes are typed as `Column` in stubs; they are
# `InstrumentedAttribute` instances at runtime.
GR_LARGE_MODEL_ATTRS: list[InstrumentedAttribute] = [
SQAGeneratorRun.model_kwargs,
SQAGeneratorRun.bridge_kwargs,
Expand Down
Loading