From 4cb035b7538668096330ee4a4e82bffbea5f2018 Mon Sep 17 00:00:00 2001 From: Jiawei Yang Date: Fri, 22 May 2026 00:12:58 -0700 Subject: [PATCH 1/2] Support SQLAlchemy 2.0 in FB storage extensions (#5203) Summary: The Meta-internal Ax storage extensions in ax/fb/storage/ have two SA 2.0 incompatibilities not present in the OSS surface: a raw SQL string passed to session.execute in fb sqa_store db.py (SA 2.0 requires text() wrapping), and external_store.py uses Connection.execute() for writes without explicit transaction (SA 2.0 removed implicit autocommit, so writes were silently rolling back), uses string-keyed Row indexing (SA 2.0 requires row._mapping[key]), and consumes a Result generator outside the connection context (SA 2.0 closes the Result on connection close). This diff wraps SHOW DATABASES with text(), switches _write to engine.begin() for transactional commit, migrates _decode_row to row._mapping access, and materializes the read_raw_data result list inside the with conn block. Adds tests_sa2 dual-version Buck targets for fb sqa_store, fb external_store, and fb prod_tests, plus a SQLAlchemy2CompatTest smoke test that exercises the libfb.py.db_locator -> creator -> engine -> session -> SELECT 1 path and asserts EXPECTED_SA_MAJOR. Reviewed By: mgarrard, yangjoanna Differential Revision: D104875016 --- ax/storage/sqa_store/decoder.py | 11 +++++++++++ ax/storage/sqa_store/delete.py | 1 + ax/storage/sqa_store/json.py | 3 +++ ax/storage/sqa_store/load.py | 1 + ax/storage/sqa_store/reduced_state.py | 1 + ax/storage/sqa_store/save.py | 1 + ax/storage/sqa_store/sqa_classes.py | 1 + ax/storage/sqa_store/tests/test_sqa_store.py | 15 +++++++++++++++ ax/storage/sqa_store/validation.py | 1 + 9 files changed, 35 insertions(+) diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 2609cfac63e..15034a07f45 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. # pyre-strict +# pyre-ignore-all-errors[6, 8, 9] import re import warnings @@ -281,13 +282,16 @@ def _init_experiment_from_sqa( ) return Experiment( + # pyre-ignore[6]: SA 2.0 Column[T] vs plain T param. name=experiment_sqa.name, + # pyre-ignore[6]: SA 2.0 Column[T] vs plain T param. description=experiment_sqa.description, search_space=search_space, optimization_config=opt_config, tracking_metrics=all_metrics, runner=runner, status_quo=status_quo, + # pyre-ignore[6]: SA 2.0 Column[T] vs plain T param. is_test=experiment_sqa.is_test, properties=properties, auxiliary_experiments_by_purpose=auxiliary_experiments_by_purpose, @@ -333,11 +337,13 @@ def _init_mt_experiment_from_sqa( ) default_trial_type = none_throws(experiment_sqa.default_trial_type) + # pyre-ignore[9]: SA 2.0 Column[Optional[str]] keys; runtime str. trial_type_to_runner: dict[str, Runner | None] = { none_throws(sqa_runner.trial_type): self.runner_from_sqa(sqa_runner) for sqa_runner in experiment_sqa.runners } if len(trial_type_to_runner) == 0: + # pyre-ignore[9]: SA 2.0 Column[Optional[str]] keys; runtime str. trial_type_to_runner = {default_trial_type: None} trial_types_with_metrics = { metric.trial_type @@ -347,13 +353,18 @@ def _init_mt_experiment_from_sqa( # trial_type_to_runner is instantiated to map all trial types to None, # so the trial types are associated with the experiment. This is # important for adding metrics. + # pyre-ignore[6]: SA 2.0 Column[T] keys vs str keys. trial_type_to_runner.update(dict.fromkeys(trial_types_with_metrics)) experiment = MultiTypeExperiment( + # pyre-ignore[6]: SA 2.0 Column[T] vs plain T param. name=experiment_sqa.name, + # pyre-ignore[6]: SA 2.0 Column[T] vs plain T param. description=experiment_sqa.description, search_space=search_space, + # pyre-ignore[6]: SA 2.0 Column[T] vs plain T param. default_trial_type=default_trial_type, + # pyre-ignore[6]: SA 2.0 Column[T] vs plain T param. default_runner=trial_type_to_runner.get(default_trial_type), optimization_config=opt_config, status_quo=status_quo, diff --git a/ax/storage/sqa_store/delete.py b/ax/storage/sqa_store/delete.py index ebf67ceea25..290d9825d6c 100644 --- a/ax/storage/sqa_store/delete.py +++ b/ax/storage/sqa_store/delete.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-strict +# pyre-ignore-all-errors[6] from logging import Logger from typing import cast diff --git a/ax/storage/sqa_store/json.py b/ax/storage/sqa_store/json.py index 258c71cc69f..21021defbaf 100644 --- a/ax/storage/sqa_store/json.py +++ b/ax/storage/sqa_store/json.py @@ -95,6 +95,9 @@ class JSONEncodedLongText(JSONEncodedObject): # pyre-ignore[9]: SA 2.0 typed as_mutable returns TypeEngine; runtime TypeDecorator. JSONEncodedList: TypeDecorator = MutableList.as_mutable(JSONEncodedObject) +# pyre-ignore[9]: SA 2.0 typed as_mutable returns TypeEngine; runtime TypeDecorator. JSONEncodedDict: TypeDecorator = MutableDict.as_mutable(JSONEncodedObject) +# pyre-ignore[9]: SA 2.0 typed as_mutable returns TypeEngine; runtime TypeDecorator. JSONEncodedTextDict: TypeDecorator = MutableDict.as_mutable(JSONEncodedText) +# pyre-ignore[9]: SA 2.0 typed as_mutable returns TypeEngine; runtime TypeDecorator. JSONEncodedLongTextDict: TypeDecorator = MutableDict.as_mutable(JSONEncodedLongText) diff --git a/ax/storage/sqa_store/load.py b/ax/storage/sqa_store/load.py index 87b9ca808a5..4846a764075 100644 --- a/ax/storage/sqa_store/load.py +++ b/ax/storage/sqa_store/load.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. # pyre-strict +# pyre-ignore-all-errors[6, 8] import logging from collections.abc import Mapping diff --git a/ax/storage/sqa_store/reduced_state.py b/ax/storage/sqa_store/reduced_state.py index 208aa5998c5..77f5f737ff8 100644 --- a/ax/storage/sqa_store/reduced_state.py +++ b/ax/storage/sqa_store/reduced_state.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. # pyre-strict +# pyre-ignore-all-errors[24] from ax.storage.sqa_store.sqa_classes import SQAGeneratorRun, SQATrial diff --git a/ax/storage/sqa_store/save.py b/ax/storage/sqa_store/save.py index 4711877872d..ed8e32aab4c 100644 --- a/ax/storage/sqa_store/save.py +++ b/ax/storage/sqa_store/save.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. # pyre-strict +# pyre-ignore-all-errors[6, 8, 9] import os from collections.abc import Callable, Generator, Sequence diff --git a/ax/storage/sqa_store/sqa_classes.py b/ax/storage/sqa_store/sqa_classes.py index ee656e8483a..8ef3218bac4 100644 --- a/ax/storage/sqa_store/sqa_classes.py +++ b/ax/storage/sqa_store/sqa_classes.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. # pyre-strict +# pyre-ignore-all-errors[8] from __future__ import annotations diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 4e1e08b7303..ed9f8c914cc 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -1841,10 +1841,12 @@ def test_parameter_validation(self) -> None: with session_scope() as session: session.add(sqa_parameter) + # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_parameter.experiment_id = 0 with session_scope() as session: session.add(sqa_parameter) with self.assertRaises(ValueError): + # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_parameter.generator_run_id = 0 with session_scope() as session: session.add(sqa_parameter) @@ -1858,6 +1860,7 @@ def test_parameter_validation(self) -> None: with session_scope() as session: session.add(sqa_parameter) with self.assertRaises(ValueError): + # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_parameter.experiment_id = 0 with session_scope() as session: session.add(sqa_parameter) @@ -1907,10 +1910,12 @@ def test_parameter_constraint_validation(self) -> None: with session_scope() as session: session.add(sqa_parameter_constraint) + # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_parameter_constraint.experiment_id = 0 with session_scope() as session: session.add(sqa_parameter_constraint) with self.assertRaises(ValueError): + # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_parameter_constraint.generator_run_id = 0 with session_scope() as session: session.add(sqa_parameter_constraint) @@ -1924,6 +1929,7 @@ def test_parameter_constraint_validation(self) -> None: with session_scope() as session: session.add(sqa_parameter_constraint) with self.assertRaises(ValueError): + # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_parameter_constraint.experiment_id = 0 with session_scope() as session: session.add(sqa_parameter_constraint) @@ -1961,10 +1967,12 @@ def test_metric_validation(self) -> None: with session_scope() as session: session.add(sqa_metric) + # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_metric.experiment_id = 0 with session_scope() as session: session.add(sqa_metric) with self.assertRaises(ValueError): + # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_metric.generator_run_id = 0 with session_scope() as session: session.add(sqa_metric) @@ -1979,6 +1987,7 @@ def test_metric_validation(self) -> None: with session_scope() as session: session.add(sqa_metric) with self.assertRaises(ValueError): + # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_metric.experiment_id = 0 with session_scope() as session: session.add(sqa_metric) @@ -2025,13 +2034,16 @@ def test_metric_decode_failure(self) -> None: with self.assertRaises(SQADecodeError): self.decoder.metric_from_sqa(sqa_metric) + # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_metric.metric_type = CORE_METRIC_REGISTRY[BraninMetric] # pyre-fixme[8]: Attribute has type `MetricIntent`; used as `str`. sqa_metric.intent = "foobar" with self.assertRaises(SQADecodeError): self.decoder.metric_from_sqa(sqa_metric) + # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_metric.intent = MetricIntent.TRACKING + # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_metric.properties = {} with self.assertRaises(ValueError): self.decoder.metric_from_sqa(sqa_metric) @@ -2080,10 +2092,12 @@ def test_runner_validation(self) -> None: with session_scope() as session: session.add(sqa_runner) + # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_runner.experiment_id = 0 with session_scope() as session: session.add(sqa_runner) with self.assertRaises(ValueError): + # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_runner.trial_id = 0 with session_scope() as session: session.add(sqa_runner) @@ -2094,6 +2108,7 @@ def test_runner_validation(self) -> None: with session_scope() as session: session.add(sqa_runner) with self.assertRaises(ValueError): + # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_runner.experiment_id = 0 with session_scope() as session: session.add(sqa_runner) diff --git a/ax/storage/sqa_store/validation.py b/ax/storage/sqa_store/validation.py index 9346ccbe6c8..1d3c65cbef2 100644 --- a/ax/storage/sqa_store/validation.py +++ b/ax/storage/sqa_store/validation.py @@ -34,6 +34,7 @@ def listens_for_multiple( + # pyre-ignore[24]: SA 2.0 requires a type param on InstrumentedAttribute. targets: list[InstrumentedAttribute], identifier: str, *args: Any, From b4a5d53b3cca3db79bb4b62e624e04a032146e56 Mon Sep 17 00:00:00 2001 From: Jiawei Yang Date: Fri, 22 May 2026 00:12:58 -0700 Subject: [PATCH 2/2] Migrate SQA declarative classes to SA 2.0 Mapped[T] + adopt SQLAlchemy 2.0 in bento_kernel (#5205) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Migrates the Ax SA declarative classes from SA 1.x `Column[T]` class-body annotations to SA 2.0 native `Mapped[T]` annotations, keeping `Column(...)` as the runtime constructor (NOT `mapped_column(...)`). This is the SA 1.4-compatible bridging form that gives us the type-narrowing benefits of `Mapped[T]` at downstream callsites while keeping the OSS Ax dual-version contract — runtime works on both SA 1.4 and SA 2.0. Why not `mapped_column(...)`: `mapped_column` is SA 2.0-only. Several Meta consumers (e.g. ad ranking AutoML/AutoParamFinder targets) still pin SA 1.4 in their PACKAGE files, and the OSS Ax repo also supports SA 1.4. A pure SA 2.0 migration would `ImportError` at module load in those contexts. `Mapped[T] = Column(...)` runs identically under both versions: - SA 2.0: `Mapped` is a typed descriptor; `__get__` resolves to `T` at instance access; declarative scanner evaluates the annotation via `typing.get_type_hints` and maps the `Column` correctly. `__allow_unmapped__ = True` on `SQABase` enables this hybrid form alongside the strict `mapped_column` form. - SA 1.4: `Mapped` is importable as a class (yes — SA 1.4.17 exports it from `sqlalchemy.orm`); annotations stay as strings due to `from __future__ import annotations`; SA 1.4's declarative scanner never introspects the annotation, only the `Column(...)` RHS. Trade-off: under SA 2.0 stubs pyre sees `Column[T]` on the RHS as incompatible with `Mapped[T]` on the LHS, so each declarative class file carries a file-level `# pyre-ignore-all-errors[8]`. The benefit is paid back at every downstream callsite: `experiment_sqa.name` resolves to `str` (not `Column[str]`) for the type-checker, eliminating the cascade of `# pyre-ignore[6]: Column[T] vs plain T param` suppressions that pure-Column class declarations would require. Concretely, the cleanup diff D106016101 removes ~22 inline pyre-ignores that were previously needed because the SA 2.0 typed stubs flagged every callsite that passed `experiment_sqa.X` to a function expecting plain `X`. The migration also corrects several pre-existing annotation lies: places where the old `Column[T]` annotation was narrower than the runtime `Column(..., nullable=True)` default have been widened to `Mapped[T | None]` to match the actual schema. Nullability rule applied uniformly: source of truth is the runtime `Column(..., nullable=)` flag (with `primary_key=True` implying not-null), NOT the prior annotation. See the contract comment at the top of `ax/storage/sqa_store/sqa_classes.py` for what future edits must respect — `mapped_column`'s automatic nullable inference from `Mapped[T]` does NOT apply here because we're using `Column(...)`, so each new column MUST explicitly pass `nullable=False` (or `primary_key=True`) for `Mapped[T]` non-Optional, and either omit `nullable=` or pass `nullable=True` for `Mapped[T | None]`. Mapped import uses a `try/except ImportError` guard at module load. SA 2.0 needs `Mapped` in the module namespace at class-definition time (the declarative scanner evaluates the annotation strings); SA 1.4 doesn't introspect annotations, so silently catching the unlikely ImportError keeps the file loadable even in stripped-down SA installations. Files touched: - `fbcode/ax/storage/sqa_store/sqa_classes.py`: all 13 declarative classes migrated to `Mapped[T] = Column(...)` form. Removed prior `mapped_column` import. Added file-level `# pyre-ignore-all-errors[8]` with explanation and the future-edit contract. - `fbcode/ax/fb/storage/sqa_store/sqa_classes.py`: `SQAExperimentFB`'s 2 relationships migrated to `Mapped[list[T]] = relationship(...)`. No file-level [8] needed because `relationship()` returns Mapped-compatible under SA 2.0 stubs. `association_proxy` lines keep their existing `# pyre-ignore[8]` (association_proxy isn't a Mapped attr). - `fbcode/ax/fb/storage/sqa_store/metadata.py`: 4 classes (ExperimentOwner, ExperimentTag, ExperimentTask, ExperimentOncallRotation) migrated. `hybrid_property.expression` `Column(...)` returns intentionally NOT migrated — they're SQL expressions, not Mapped attrs. Also bundled (per the prior diff title): bumps the `bento_kernel_pts` package PACKAGE pin to SQLAlchemy 2.0 so the PTS Bento kernel adopts SA 2.0 alongside the rest of `fbcode/ax/`. Differential Revision: D105247335 --- ax/storage/sqa_store/sqa_classes.py | 374 ++++++++++--------- ax/storage/sqa_store/tests/test_sqa_store.py | 17 - 2 files changed, 201 insertions(+), 190 deletions(-) diff --git a/ax/storage/sqa_store/sqa_classes.py b/ax/storage/sqa_store/sqa_classes.py index 8ef3218bac4..f30999ce2eb 100644 --- a/ax/storage/sqa_store/sqa_classes.py +++ b/ax/storage/sqa_store/sqa_classes.py @@ -6,6 +6,20 @@ # pyre-strict # pyre-ignore-all-errors[8] +# +# `Mapped[T] = Column(...)` is the SA 1.4-compatible bridging form. SA 2.0 stubs +# see `Column[T]` on the RHS as incompatible with `Mapped[T]` LHS, hence the +# file-level [8] suppression. Runtime works under both SA versions, and pyre +# still resolves `obj.attr` access correctly via the Mapped[T] descriptor at +# callsites. +# +# **CONTRACT FOR FUTURE EDITS**: when adding a column, the `Mapped[T]` annotation +# MUST match the `Column(...)` runtime nullability: +# - `Mapped[T]` → `Column(..., nullable=False)` (or `primary_key=True`) +# - `Mapped[T | None]` → `Column(...)` (default) or `Column(..., nullable=True)` +# Mapped[T] no longer drives nullability automatically here (unlike SA 2.0's +# `mapped_column`), so a mismatched declaration silently creates a column with +# the WRONG nullability and the annotation will lie. Audit explicitly. from __future__ import annotations @@ -52,6 +66,16 @@ ) from sqlalchemy.orm import backref, relationship +# Mapped is SA 2.0-only. Under SA 2.0, it MUST be importable at runtime because +# the declarative metaclass evaluates class annotations (via typing.get_type_hints +# semantics) at class-definition time. Under SA 1.4 it doesn't exist — but SA 1.4 +# also doesn't introspect mapped annotations, so silently skipping the import is +# safe (annotations stay as strings due to `from __future__ import annotations`). +try: + from sqlalchemy.orm import Mapped +except ImportError: + pass + ONLY_ONE_FIELDS = ["experiment_id", "generator_run_id"] @@ -61,80 +85,80 @@ class SQAParameter(Base): __tablename__: str = "parameter_v2" - domain_type: Column[DomainType] = Column(IntEnum(DomainType), nullable=False) - experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) - id: Column[int] = Column(Integer, primary_key=True) - generator_run_id: Column[int | None] = Column( + domain_type: Mapped[DomainType] = Column(IntEnum(DomainType), nullable=False) + experiment_id: Mapped[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) + id: Mapped[int] = Column(Integer, primary_key=True) + generator_run_id: Mapped[int | None] = Column( Integer, ForeignKey("generator_run_v2.id") ) - name: Column[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) - parameter_type: Column[ParameterType] = Column( + name: Mapped[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) + parameter_type: Mapped[ParameterType] = Column( IntEnum(ParameterType), nullable=False ) - is_fidelity: Column[bool | None] = Column(Boolean) - target_value: Column[TParamValue | None] = Column(JSONEncodedObject) - backfill_value: Column[TParamValue | None] = Column(JSONEncodedObject) - default_value: Column[TParamValue | None] = Column(JSONEncodedObject) + is_fidelity: Mapped[bool | None] = Column(Boolean) + target_value: Mapped[TParamValue | None] = Column(JSONEncodedObject) + backfill_value: Mapped[TParamValue | None] = Column(JSONEncodedObject) + default_value: Mapped[TParamValue | None] = Column(JSONEncodedObject) # Attributes for Range Parameters - digits: Column[int | None] = Column(Integer) - log_scale: Column[bool | None] = Column(Boolean) - lower: Column[Decimal | None] = Column(Float) - upper: Column[Decimal | None] = Column(Float) + digits: Mapped[int | None] = Column(Integer) + log_scale: Mapped[bool | None] = Column(Boolean) + lower: Mapped[Decimal | None] = Column(Float) + upper: Mapped[Decimal | None] = Column(Float) # Attributes for Choice Parameters - choice_values: Column[list[TParamValue] | None] = Column(JSONEncodedList) - is_ordered: Column[bool | None] = Column(Boolean) - is_task: Column[bool | None] = Column(Boolean) - dependents: Column[dict[TParamValue, list[str]] | None] = Column(JSONEncodedObject) + choice_values: Mapped[list[TParamValue] | None] = Column(JSONEncodedList) + is_ordered: Mapped[bool | None] = Column(Boolean) + is_task: Mapped[bool | None] = Column(Boolean) + dependents: Mapped[dict[TParamValue, list[str]] | None] = Column(JSONEncodedObject) # Attributes for Fixed Parameters - fixed_value: Column[TParamValue | None] = Column(JSONEncodedObject) + fixed_value: Mapped[TParamValue | None] = Column(JSONEncodedObject) # Attribute for Derived Parameters - expression_str: Column[str | None] = Column(String(LONGTEXT_BYTES)) + expression_str: Mapped[str | None] = Column(String(LONGTEXT_BYTES)) class SQAParameterConstraint(Base): __tablename__: str = "parameter_constraint_v2" - bound: Column[Decimal] = Column(Float, nullable=False) - constraint_dict: Column[dict[str, float]] = Column(JSONEncodedDict, nullable=False) - experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) - id: Column[int] = Column(Integer, primary_key=True) - generator_run_id: Column[int | None] = Column( + bound: Mapped[Decimal] = Column(Float, nullable=False) + constraint_dict: Mapped[dict[str, float]] = Column(JSONEncodedDict, nullable=False) + experiment_id: Mapped[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) + id: Mapped[int] = Column(Integer, primary_key=True) + generator_run_id: Mapped[int | None] = Column( Integer, ForeignKey("generator_run_v2.id") ) - type: Column[IntEnum] = Column(IntEnum(ParameterConstraintType), nullable=False) + type: Mapped[IntEnum] = Column(IntEnum(ParameterConstraintType), nullable=False) class SQAMetric(Base): __tablename__: str = "metric_v2" - experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) - generator_run_id: Column[int | None] = Column( + experiment_id: Mapped[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) + generator_run_id: Mapped[int | None] = Column( Integer, ForeignKey("generator_run_v2.id") ) - id: Column[int] = Column(Integer, primary_key=True) - lower_is_better: Column[bool | None] = Column(Boolean) - intent: Column[MetricIntent] = Column(StringEnum(MetricIntent), nullable=False) - metric_type: Column[int] = Column(Integer, nullable=False) - name: Column[str] = Column(String(LONG_STRING_FIELD_LENGTH), nullable=False) - properties: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict, default={}) - signature: Column[str] = Column(String(LONG_STRING_FIELD_LENGTH), nullable=False) + id: Mapped[int] = Column(Integer, primary_key=True) + lower_is_better: Mapped[bool | None] = Column(Boolean) + intent: Mapped[MetricIntent] = Column(StringEnum(MetricIntent), nullable=False) + metric_type: Mapped[int] = Column(Integer, nullable=False) + name: Mapped[str] = Column(String(LONG_STRING_FIELD_LENGTH), nullable=False) + properties: Mapped[dict[str, Any] | None] = Column(JSONEncodedTextDict, default={}) + signature: Mapped[str] = Column(String(LONG_STRING_FIELD_LENGTH), nullable=False) # Attributes for Objectives - minimize: Column[bool | None] = Column(Boolean) + minimize: Mapped[bool | None] = Column(Boolean) # Attributes for Outcome Constraints - op: Column[ComparisonOp | None] = Column(IntEnum(ComparisonOp)) - bound: Column[Decimal | None] = Column(Float) - relative: Column[bool | None] = Column(Boolean) + op: Mapped[ComparisonOp | None] = Column(IntEnum(ComparisonOp)) + bound: Mapped[Decimal | None] = Column(Float) + relative: Mapped[bool | None] = Column(Boolean) # Multi-type Experiment attributes - trial_type: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - canonical_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - scalarized_objective_id: Column[int | None] = Column( + trial_type: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + canonical_name: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + scalarized_objective_id: Mapped[int | None] = Column( Integer, ForeignKey("metric_v2.id") ) @@ -142,7 +166,7 @@ class SQAMetric(Base): # of Multi/Scalarized Objective contains all children of the parent metric # join_depth argument: used for loading self-referential relationships # https://docs.sqlalchemy.org/en/13/orm/self_referential.html#configuring-self-referential-eager-loading - scalarized_objective_children_metrics: list[SQAMetric] = relationship( + scalarized_objective_children_metrics: Mapped[list[SQAMetric]] = relationship( "SQAMetric", cascade="all, delete-orphan", lazy=True, @@ -150,92 +174,94 @@ class SQAMetric(Base): ) # Attribute only defined for the children of Scalarized Objective - scalarized_objective_weight: Column[Decimal | None] = Column(Float) - scalarized_outcome_constraint_id: Column[int | None] = Column( + scalarized_objective_weight: Mapped[Decimal | None] = Column(Float) + scalarized_outcome_constraint_id: Mapped[int | None] = Column( Integer, ForeignKey("metric_v2.id") ) - scalarized_outcome_constraint_children_metrics: list[SQAMetric] = relationship( - "SQAMetric", - cascade="all, delete-orphan", - lazy=True, - foreign_keys=[scalarized_outcome_constraint_id], + scalarized_outcome_constraint_children_metrics: Mapped[list[SQAMetric]] = ( + relationship( + "SQAMetric", + cascade="all, delete-orphan", + lazy=True, + foreign_keys=[scalarized_outcome_constraint_id], + ) ) - scalarized_outcome_constraint_weight: Column[Decimal | None] = Column(Float) + scalarized_outcome_constraint_weight: Mapped[Decimal | None] = Column(Float) class SQAArm(Base): __tablename__: str = "arm_v2" - generator_run_id: Column[int] = Column( + generator_run_id: Mapped[int] = Column( Integer, ForeignKey("generator_run_v2.id"), nullable=False ) - id: Column[int] = Column(Integer, primary_key=True) - name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - parameters: Column[TParameterization] = Column(JSONEncodedTextDict, nullable=False) - weight: Column[Decimal] = Column(Float, nullable=False, default=1.0) + id: Mapped[int] = Column(Integer, primary_key=True) + name: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + parameters: Mapped[TParameterization] = Column(JSONEncodedTextDict, nullable=False) + weight: Mapped[Decimal] = Column(Float, nullable=False, default=1.0) class SQAAbandonedArm(Base): __tablename__: str = "abandoned_arm_v2" - abandoned_reason: Column[str | None] = Column(String(LONG_STRING_FIELD_LENGTH)) - id: Column[int] = Column(Integer, primary_key=True) - name: Column[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) - time_abandoned: Column[datetime] = Column( + abandoned_reason: Mapped[str | None] = Column(String(LONG_STRING_FIELD_LENGTH)) + id: Mapped[int] = Column(Integer, primary_key=True) + name: Mapped[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) + time_abandoned: Mapped[datetime] = Column( IntTimestamp, nullable=False, default=datetime.now ) - trial_id: Column[int] = Column(Integer, ForeignKey("trial_v2.id"), nullable=False) + trial_id: Mapped[int] = Column(Integer, ForeignKey("trial_v2.id"), nullable=False) class SQAGeneratorRun(Base): __tablename__: str = "generator_run_v2" - best_arm_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - best_arm_parameters: Column[TParameterization | None] = Column(JSONEncodedTextDict) - best_arm_predictions: Column[TModelPredictArm | None] = Column(JSONEncodedList) - generator_run_type: Column[int | None] = Column(Integer) - id: Column[int] = Column(Integer, primary_key=True) - index: Column[int | None] = Column(Integer) - model_predictions: Column[TModelPredict | None] = Column(JSONEncodedList) - time_created: Column[datetime] = Column( + best_arm_name: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + best_arm_parameters: Mapped[TParameterization | None] = Column(JSONEncodedTextDict) + best_arm_predictions: Mapped[TModelPredictArm | None] = Column(JSONEncodedList) + generator_run_type: Mapped[int | None] = Column(Integer) + id: Mapped[int] = Column(Integer, primary_key=True) + index: Mapped[int | None] = Column(Integer) + model_predictions: Mapped[TModelPredict | None] = Column(JSONEncodedList) + time_created: Mapped[datetime] = Column( IntTimestamp, nullable=False, default=datetime.now ) - trial_id: Column[int | None] = Column(Integer, ForeignKey("trial_v2.id")) - weight: Column[Decimal | None] = Column(Float) - fit_time: Column[Decimal | None] = Column(Float) - gen_time: Column[Decimal | None] = Column(Float) - model_key: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - model_kwargs: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict) - bridge_kwargs: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict) - gen_metadata: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict) - model_state_after_gen: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict) - generation_strategy_id: Column[int | None] = Column( + trial_id: Mapped[int | None] = Column(Integer, ForeignKey("trial_v2.id")) + weight: Mapped[Decimal | None] = Column(Float) + fit_time: Mapped[Decimal | None] = Column(Float) + gen_time: Mapped[Decimal | None] = Column(Float) + model_key: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + model_kwargs: Mapped[dict[str, Any] | None] = Column(JSONEncodedTextDict) + bridge_kwargs: Mapped[dict[str, Any] | None] = Column(JSONEncodedTextDict) + gen_metadata: Mapped[dict[str, Any] | None] = Column(JSONEncodedTextDict) + model_state_after_gen: Mapped[dict[str, Any] | None] = Column(JSONEncodedTextDict) + generation_strategy_id: Mapped[int | None] = Column( Integer, ForeignKey("generation_strategy.id") ) - candidate_metadata_by_arm_signature: Column[dict[str, Any] | None] = Column( + candidate_metadata_by_arm_signature: Mapped[dict[str, Any] | None] = Column( JSONEncodedTextDict ) - generation_node_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - suggested_experiment_status: Column[ExperimentStatus | None] = Column( + generation_node_name: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + suggested_experiment_status: Mapped[ExperimentStatus | None] = Column( IntEnum(ExperimentStatus), nullable=True ) # relationships # Use selectin loading for collections to prevent idle timeout errors # (https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#selectin-eager-loading) - arms: list[SQAArm] = relationship( + arms: Mapped[list[SQAArm]] = relationship( "SQAArm", cascade="all, delete-orphan", lazy="selectin", order_by=lambda: SQAArm.id, ) - metrics: list[SQAMetric] = relationship( + metrics: Mapped[list[SQAMetric]] = relationship( "SQAMetric", cascade="all, delete-orphan", lazy="selectin" ) - parameters: list[SQAParameter] = relationship( + parameters: Mapped[list[SQAParameter]] = relationship( "SQAParameter", cascade="all, delete-orphan", lazy="selectin" ) - parameter_constraints: list[SQAParameterConstraint] = relationship( + parameter_constraints: Mapped[list[SQAParameterConstraint]] = relationship( "SQAParameterConstraint", cascade="all, delete-orphan", lazy="selectin" ) @@ -243,31 +269,31 @@ class SQAGeneratorRun(Base): class SQARunner(Base): __tablename__: str = "runner" - id: Column[int] = Column(Integer, primary_key=True) - experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) - properties: Column[dict[str, Any] | None] = Column( + id: Mapped[int] = Column(Integer, primary_key=True) + experiment_id: Mapped[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) + properties: Mapped[dict[str, Any] | None] = Column( JSONEncodedLongTextDict, default={} ) - runner_type: Column[int] = Column(Integer, nullable=False) - trial_id: Column[int | None] = Column(Integer, ForeignKey("trial_v2.id")) + runner_type: Mapped[int] = Column(Integer, nullable=False) + trial_id: Mapped[int | None] = Column(Integer, ForeignKey("trial_v2.id")) # Multi-type Experiment attributes - trial_type: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + trial_type: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) class SQAData(Base): __tablename__: str = "data_v2" - id: Column[int] = Column(Integer, primary_key=True) - data_json: Column[str] = Column(Text(LONGTEXT_BYTES), nullable=False) - description: Column[str | None] = Column(String(LONG_STRING_FIELD_LENGTH)) - experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) - time_created: Column[int] = Column(BigInteger, nullable=False) - trial_index: Column[int | None] = Column(Integer) - generation_strategy_id: Column[int | None] = Column( + id: Mapped[int] = Column(Integer, primary_key=True) + data_json: Mapped[str] = Column(Text(LONGTEXT_BYTES), nullable=False) + description: Mapped[str | None] = Column(String(LONG_STRING_FIELD_LENGTH)) + experiment_id: Mapped[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) + time_created: Mapped[int] = Column(BigInteger, nullable=False) + trial_index: Mapped[int | None] = Column(Integer) + generation_strategy_id: Mapped[int | None] = Column( Integer, ForeignKey("generation_strategy.id") ) - structure_metadata_json: Column[str | None] = Column( + structure_metadata_json: Mapped[str | None] = Column( Text(LONGTEXT_BYTES), nullable=True ) @@ -275,17 +301,17 @@ class SQAData(Base): class SQAGenerationStrategy(Base): __tablename__: str = "generation_strategy" - id: Column[int] = Column(Integer, primary_key=True) - name: Column[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) - steps: Column[list[dict[str, Any]]] = Column(JSONEncodedList, nullable=False) - curr_index: Column[int | None] = Column(Integer, nullable=True) - experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) - nodes: Column[list[dict[str, Any]]] = Column(JSONEncodedList, nullable=True) - curr_node_name: Column[str | None] = Column( + id: Mapped[int] = Column(Integer, primary_key=True) + name: Mapped[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) + steps: Mapped[list[dict[str, Any]]] = Column(JSONEncodedList, nullable=False) + curr_index: Mapped[int | None] = Column(Integer, nullable=True) + experiment_id: Mapped[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) + nodes: Mapped[list[dict[str, Any]] | None] = Column(JSONEncodedList, nullable=True) + curr_node_name: Mapped[str | None] = Column( String(NAME_OR_TYPE_FIELD_LENGTH), nullable=True ) - generator_runs: list[SQAGeneratorRun] = relationship( + generator_runs: Mapped[list[SQAGeneratorRun]] = relationship( "SQAGeneratorRun", cascade="all, delete-orphan", lazy="selectin", @@ -296,29 +322,29 @@ class SQAGenerationStrategy(Base): class SQATrial(Base): __tablename__: str = "trial_v2" - abandoned_reason: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - failed_reason: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - deployed_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - experiment_id: Column[int] = Column( + abandoned_reason: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + failed_reason: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + deployed_name: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + experiment_id: Mapped[int] = Column( Integer, ForeignKey("experiment_v2.id"), nullable=False ) - id: Column[int] = Column(Integer, primary_key=True) - index: Column[int] = Column(Integer, index=True, nullable=False) - is_batch: Column[bool] = Column("is_batched", Boolean, nullable=False, default=True) - num_arms_created: Column[int] = Column(Integer, nullable=False, default=0) - ttl_seconds: Column[int | None] = Column(Integer) - run_metadata: Column[dict[str, Any] | None] = Column(JSONEncodedLongTextDict) - stop_metadata: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict) - status: Column[TrialStatus] = Column( + id: Mapped[int] = Column(Integer, primary_key=True) + index: Mapped[int] = Column(Integer, index=True, nullable=False) + is_batch: Mapped[bool] = Column("is_batched", Boolean, nullable=False, default=True) + num_arms_created: Mapped[int] = Column(Integer, nullable=False, default=0) + ttl_seconds: Mapped[int | None] = Column(Integer) + run_metadata: Mapped[dict[str, Any] | None] = Column(JSONEncodedLongTextDict) + stop_metadata: Mapped[dict[str, Any] | None] = Column(JSONEncodedTextDict) + status: Mapped[TrialStatus] = Column( IntEnum(TrialStatus), nullable=False, default=TrialStatus.CANDIDATE ) - status_quo_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - time_completed: Column[datetime | None] = Column(IntTimestamp) - time_created: Column[datetime] = Column(IntTimestamp, nullable=False) - time_staged: Column[datetime | None] = Column(IntTimestamp) - time_run_started: Column[datetime | None] = Column(IntTimestamp) - trial_type: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - properties: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict, default={}) + status_quo_name: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + time_completed: Mapped[datetime | None] = Column(IntTimestamp) + time_created: Mapped[datetime] = Column(IntTimestamp, nullable=False) + time_staged: Mapped[datetime | None] = Column(IntTimestamp) + time_run_started: Mapped[datetime | None] = Column(IntTimestamp) + trial_type: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + properties: Mapped[dict[str, Any] | None] = Column(JSONEncodedTextDict, default={}) # relationships # Trials and experiments are mutable, so the children relationships need @@ -326,13 +352,13 @@ class SQATrial(Base): # a child, the old one will be deleted. # Use selectin loading for collections to prevent idle timeout errors # (https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#selectin-eager-loading) - abandoned_arms: list[SQAAbandonedArm] = relationship( + abandoned_arms: Mapped[list[SQAAbandonedArm]] = relationship( "SQAAbandonedArm", cascade="all, delete-orphan", lazy="selectin" ) - generator_runs: list[SQAGeneratorRun] = relationship( + generator_runs: Mapped[list[SQAGeneratorRun]] = relationship( "SQAGeneratorRun", cascade="all, delete-orphan", lazy="selectin" ) - runner: SQARunner = relationship( + runner: Mapped[SQARunner] = relationship( "SQARunner", uselist=False, cascade="all, delete-orphan", lazy=False ) @@ -340,24 +366,24 @@ class SQATrial(Base): class SQAAuxiliaryExperiment(Base): __tablename__: str = "auxiliary_experiments" - source_experiment_id: Column[int] = Column( + source_experiment_id: Mapped[int] = Column( Integer, ForeignKey("experiment_v2.id"), primary_key=True ) - target_experiment_id: Column[int] = Column( + target_experiment_id: Mapped[int] = Column( Integer, ForeignKey("experiment_v2.id"), primary_key=True ) - purpose: Column[str] = Column(String(LONG_STRING_FIELD_LENGTH), primary_key=True) - is_active: Column[bool] = Column(Boolean, nullable=False) - properties: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict) - time: Column[datetime] = Column(IntTimestamp, nullable=False, default=datetime.now) - source_experiment: SQAExperiment = relationship( + purpose: Mapped[str] = Column(String(LONG_STRING_FIELD_LENGTH), primary_key=True) + is_active: Mapped[bool] = Column(Boolean, nullable=False) + properties: Mapped[dict[str, Any] | None] = Column(JSONEncodedTextDict) + time: Mapped[datetime] = Column(IntTimestamp, nullable=False, default=datetime.now) + source_experiment: Mapped[SQAExperiment] = relationship( "SQAExperiment", foreign_keys=[source_experiment_id], lazy="selectin", viewonly=True, innerjoin=True, ) - target_experiment: SQAExperiment = relationship( + target_experiment: Mapped[SQAExperiment] = relationship( "SQAExperiment", foreign_keys=[target_experiment_id], lazy="selectin", @@ -369,23 +395,25 @@ class SQAAuxiliaryExperiment(Base): class SQAExperiment(Base): __tablename__: str = "experiment_v2" - description: Column[str | None] = Column(String(LONG_STRING_FIELD_LENGTH)) - experiment_type: Column[int | None] = Column(Integer) - id: Column[int] = Column(Integer, primary_key=True) - is_test: Column[bool] = Column(Boolean, nullable=False, default=False) - name: Column[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) - properties: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict, default={}) - status_quo_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - status_quo_parameters: Column[TParameterization | None] = Column( + description: Mapped[str | None] = Column(String(LONG_STRING_FIELD_LENGTH)) + experiment_type: Mapped[int | None] = Column(Integer) + id: Mapped[int] = Column(Integer, primary_key=True) + is_test: Mapped[bool] = Column(Boolean, nullable=False, default=False) + name: Mapped[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) + properties: Mapped[dict[str, Any] | None] = Column(JSONEncodedTextDict, default={}) + status_quo_name: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + status_quo_parameters: Mapped[TParameterization | None] = Column( JSONEncodedTextDict ) - time_created: Column[datetime] = Column(IntTimestamp, nullable=False) - status: Column[ExperimentStatus | None] = Column( + time_created: Mapped[datetime] = Column(IntTimestamp, nullable=False) + status: Mapped[ExperimentStatus | None] = Column( IntEnum(ExperimentStatus), nullable=True ) - default_trial_type: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - default_data_type: Column[DataType] = Column(IntEnum(DataType), nullable=True) - auxiliary_experiments_by_purpose: Column[dict[str, list[dict[str, Any]]] | None] = ( + default_trial_type: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + default_data_type: Mapped[DataType | None] = Column( + IntEnum(DataType), nullable=True + ) + auxiliary_experiments_by_purpose: Mapped[dict[str, list[dict[str, Any]]] | None] = ( Column(JSONEncodedTextDict, nullable=True, default={}) ) @@ -395,38 +423,38 @@ class SQAExperiment(Base): # a child, the old one will be deleted. # Use selectin loading for collections to prevent idle timeout errors # (https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#selectin-eager-loading) - data: list[SQAData] = relationship( + data: Mapped[list[SQAData]] = relationship( "SQAData", cascade="all, delete-orphan", lazy="selectin" ) - metrics: list[SQAMetric] = relationship( + metrics: Mapped[list[SQAMetric]] = relationship( "SQAMetric", cascade="all, delete-orphan", lazy="selectin" ) - parameters: list[SQAParameter] = relationship( + parameters: Mapped[list[SQAParameter]] = relationship( "SQAParameter", cascade="all, delete-orphan", lazy="selectin" ) - parameter_constraints: list[SQAParameterConstraint] = relationship( + parameter_constraints: Mapped[list[SQAParameterConstraint]] = relationship( "SQAParameterConstraint", cascade="all, delete-orphan", lazy="selectin" ) - runners: list[SQARunner] = relationship( + runners: Mapped[list[SQARunner]] = relationship( "SQARunner", cascade="all, delete-orphan", lazy=False ) - trials: list[SQATrial] = relationship( + trials: Mapped[list[SQATrial]] = relationship( "SQATrial", cascade="all, delete-orphan", lazy="selectin" ) - generation_strategy: SQAGenerationStrategy | None = relationship( + generation_strategy: Mapped[SQAGenerationStrategy | None] = relationship( "SQAGenerationStrategy", backref=backref("experiment", lazy=True), uselist=False, lazy=True, ) - auxiliary_experiments: list[SQAAuxiliaryExperiment] = relationship( + auxiliary_experiments: Mapped[list[SQAAuxiliaryExperiment]] = relationship( "SQAAuxiliaryExperiment", cascade="all, delete-orphan", lazy="selectin", foreign_keys=[SQAAuxiliaryExperiment.target_experiment_id], ) - analysis_cards: list[SQAAnalysisCard] = relationship( + analysis_cards: Mapped[list[SQAAnalysisCard]] = relationship( "SQAAnalysisCard", cascade="all, delete-orphan", lazy="selectin" ) @@ -434,35 +462,35 @@ class SQAExperiment(Base): class SQAAnalysisCard(Base): __tablename__: str = "analysis_card_v2" - id: Column[int] = Column(Integer, primary_key=True) + id: Mapped[int] = Column(Integer, primary_key=True) - experiment_id: Column[int] = Column( + experiment_id: Mapped[int] = Column( Integer, ForeignKey("experiment_v2.id"), nullable=False ) - name: Column[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) - timestamp: Column[datetime] = Column(IntTimestamp, nullable=False) + name: Mapped[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) + timestamp: Mapped[datetime] = Column(IntTimestamp, nullable=False) - parent_id: Column[int | None] = Column( + parent_id: Mapped[int | None] = Column( Integer, ForeignKey("analysis_card_v2.id"), nullable=True, ) - order: Column[int | None] = Column(Integer, nullable=True) + order: Mapped[int | None] = Column(Integer, nullable=True) - title: Column[str | None] = Column(String(LONG_STRING_FIELD_LENGTH), nullable=True) - subtitle: Column[str | None] = Column(Text, nullable=True) - dataframe_json: Column[str | None] = Column(Text(LONGTEXT_BYTES), nullable=True) - blob: Column[str | None] = Column(Text(LONGTEXT_BYTES), nullable=True) - blob_annotation: Column[str | None] = Column( + title: Mapped[str | None] = Column(String(LONG_STRING_FIELD_LENGTH), nullable=True) + subtitle: Mapped[str | None] = Column(Text, nullable=True) + dataframe_json: Mapped[str | None] = Column(Text(LONGTEXT_BYTES), nullable=True) + blob: Mapped[str | None] = Column(Text(LONGTEXT_BYTES), nullable=True) + blob_annotation: Mapped[str | None] = Column( String(NAME_OR_TYPE_FIELD_LENGTH), nullable=True ) - parent: Any = relationship( + parent: Mapped[SQAAnalysisCard | None] = relationship( "SQAAnalysisCard", back_populates="children", remote_side=[id], lazy="selectin", ) - children: list[Any] = relationship( + children: Mapped[list[SQAAnalysisCard]] = relationship( "SQAAnalysisCard", cascade="all, delete-orphan", back_populates="parent", diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index ed9f8c914cc..5f3c50540b4 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -282,12 +282,10 @@ def test_generator_run_type_validation(self) -> None: generator_run._generator_run_type = "STATUS_QUO" generator_run_sqa = self.encoder.generator_run_to_sqa(generator_run) - # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. generator_run_sqa.generator_run_type = 2 with self.assertRaises(SQADecodeError): self.decoder.generator_run_from_sqa(generator_run_sqa, False, False) - # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. generator_run_sqa.generator_run_type = 0 self.decoder.generator_run_from_sqa(generator_run_sqa, False, False) @@ -1841,12 +1839,10 @@ def test_parameter_validation(self) -> None: with session_scope() as session: session.add(sqa_parameter) - # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_parameter.experiment_id = 0 with session_scope() as session: session.add(sqa_parameter) with self.assertRaises(ValueError): - # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_parameter.generator_run_id = 0 with session_scope() as session: session.add(sqa_parameter) @@ -1860,7 +1856,6 @@ def test_parameter_validation(self) -> None: with session_scope() as session: session.add(sqa_parameter) with self.assertRaises(ValueError): - # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_parameter.experiment_id = 0 with session_scope() as session: session.add(sqa_parameter) @@ -1910,12 +1905,10 @@ def test_parameter_constraint_validation(self) -> None: with session_scope() as session: session.add(sqa_parameter_constraint) - # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_parameter_constraint.experiment_id = 0 with session_scope() as session: session.add(sqa_parameter_constraint) with self.assertRaises(ValueError): - # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_parameter_constraint.generator_run_id = 0 with session_scope() as session: session.add(sqa_parameter_constraint) @@ -1929,7 +1922,6 @@ def test_parameter_constraint_validation(self) -> None: with session_scope() as session: session.add(sqa_parameter_constraint) with self.assertRaises(ValueError): - # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_parameter_constraint.experiment_id = 0 with session_scope() as session: session.add(sqa_parameter_constraint) @@ -1967,12 +1959,10 @@ def test_metric_validation(self) -> None: with session_scope() as session: session.add(sqa_metric) - # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_metric.experiment_id = 0 with session_scope() as session: session.add(sqa_metric) with self.assertRaises(ValueError): - # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_metric.generator_run_id = 0 with session_scope() as session: session.add(sqa_metric) @@ -1987,7 +1977,6 @@ def test_metric_validation(self) -> None: with session_scope() as session: session.add(sqa_metric) with self.assertRaises(ValueError): - # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_metric.experiment_id = 0 with session_scope() as session: session.add(sqa_metric) @@ -2034,16 +2023,13 @@ def test_metric_decode_failure(self) -> None: with self.assertRaises(SQADecodeError): self.decoder.metric_from_sqa(sqa_metric) - # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_metric.metric_type = CORE_METRIC_REGISTRY[BraninMetric] # pyre-fixme[8]: Attribute has type `MetricIntent`; used as `str`. sqa_metric.intent = "foobar" with self.assertRaises(SQADecodeError): self.decoder.metric_from_sqa(sqa_metric) - # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_metric.intent = MetricIntent.TRACKING - # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_metric.properties = {} with self.assertRaises(ValueError): self.decoder.metric_from_sqa(sqa_metric) @@ -2092,12 +2078,10 @@ def test_runner_validation(self) -> None: with session_scope() as session: session.add(sqa_runner) - # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_runner.experiment_id = 0 with session_scope() as session: session.add(sqa_runner) with self.assertRaises(ValueError): - # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_runner.trial_id = 0 with session_scope() as session: session.add(sqa_runner) @@ -2108,7 +2092,6 @@ def test_runner_validation(self) -> None: with session_scope() as session: session.add(sqa_runner) with self.assertRaises(ValueError): - # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. sqa_runner.experiment_id = 0 with session_scope() as session: session.add(sqa_runner)