From 0edf395fd0d1110e270ccc7cb28b278ca4edb5a8 Mon Sep 17 00:00:00 2001 From: Jiawei Yang Date: Thu, 14 May 2026 16:10:21 -0700 Subject: [PATCH] Make SQA store SQLAlchemy 2.0-compatible Summary: The OSS Ax SQA store has a hard guard in with_db_settings_base.py that raises IncompatibleDependencyVersion when SQLAlchemy major > 1, disabling SQL storage entirely. This blocks SA 2.0 adoption (T163607006) and Python 3.13/3.14 (which auto-select SA 2.0.48 from third-party). Two additional SA 2.0 incompatibilities exist in OSS: defer("col_name") in load.py and reduced_state.py, which SA 2.0 rejects in favor of class-bound attribute references. This diff removes the guard and converts the string-based loader options to attribute references. Adds a dual-version Buck test target tests_sa2 via constraint_overrides plus a self-proving TestSQLAlchemyDualVersionCompat class so each target proves its constraint took effect (EXPECTED_SA_MAJOR env var asserted at runtime). Differential Revision: D104875017 --- ax/storage/sqa_store/load.py | 10 ++-- ax/storage/sqa_store/reduced_state.py | 3 +- .../tests/test_with_db_settings_base.py | 47 +++++++++++++++++++ ax/storage/sqa_store/with_db_settings_base.py | 22 ++------- 4 files changed, 57 insertions(+), 25 deletions(-) diff --git a/ax/storage/sqa_store/load.py b/ax/storage/sqa_store/load.py index b333365369a..87b9ca808a5 100644 --- a/ax/storage/sqa_store/load.py +++ b/ax/storage/sqa_store/load.py @@ -614,10 +614,12 @@ def get_generation_strategy_sqa_reduced_state( gr_sqa_class.parameter_constraints ), defaultload(gs_sqa_class.generator_runs).lazyload(gr_sqa_class.metrics), - defaultload(gs_sqa_class.generator_runs).defer("model_kwargs"), - defaultload(gs_sqa_class.generator_runs).defer("bridge_kwargs"), - defaultload(gs_sqa_class.generator_runs).defer("model_state_after_gen"), - defaultload(gs_sqa_class.generator_runs).defer("gen_metadata"), + defaultload(gs_sqa_class.generator_runs).defer(gr_sqa_class.model_kwargs), + defaultload(gs_sqa_class.generator_runs).defer(gr_sqa_class.bridge_kwargs), + defaultload(gs_sqa_class.generator_runs).defer( + gr_sqa_class.model_state_after_gen + ), + defaultload(gs_sqa_class.generator_runs).defer(gr_sqa_class.gen_metadata), ], ) diff --git a/ax/storage/sqa_store/reduced_state.py b/ax/storage/sqa_store/reduced_state.py index 39cdf76780e..208aa5998c5 100644 --- a/ax/storage/sqa_store/reduced_state.py +++ b/ax/storage/sqa_store/reduced_state.py @@ -58,6 +58,5 @@ def get_query_options_to_defer_large_model_cols() -> list[strategy_options.Load] when loading experiment and generation strategy in reduced state. """ return [ - defaultload(SQATrial.generator_runs).defer(col.key) - for col in GR_LARGE_MODEL_ATTRS + defaultload(SQATrial.generator_runs).defer(col) for col in GR_LARGE_MODEL_ATTRS ] diff --git a/ax/storage/sqa_store/tests/test_with_db_settings_base.py b/ax/storage/sqa_store/tests/test_with_db_settings_base.py index 52f85f0c0e7..d9895104b75 100644 --- a/ax/storage/sqa_store/tests/test_with_db_settings_base.py +++ b/ax/storage/sqa_store/tests/test_with_db_settings_base.py @@ -6,15 +6,18 @@ # pyre-strict import logging +import os import random import string from unittest.mock import patch +import sqlalchemy from ax.core.base_trial import BaseTrial from ax.core.experiment import Experiment from ax.core.trial import Trial from ax.core.trial_status import TrialStatus from ax.generation_strategy.generation_strategy import GenerationStrategy +from ax.storage.sqa_store import with_db_settings_base as _wdb_module from ax.storage.sqa_store.db import init_test_engine_and_session_factory from ax.storage.sqa_store.load import ( _load_experiment, @@ -396,3 +399,47 @@ def test_try_load_generation_strategy(self) -> None: lg.output[0], ) self.assertEqual(output, generation_strategy) + + +class TestSQLAlchemyDualVersionCompat(TestCase): + """Self-proving checks that the dual-version SA 2.0 BUCK targets actually + resolved their constraint_overrides and that the SA 2.0 hard guard is gone. + + Part of the SA 2.0 dual-version migration (T163607006). + """ + + def test_module_level_dbsettings_is_defined(self) -> None: + """The SA 2.0 hard guard previously set DBSettings = None at module level + when SA major > 1, breaking WithDBSettingsBase.__init__ type checks. Now + that the guard is removed, DBSettings must always resolve to the real + type when SQLAlchemy is importable. Uses getattr because DBSettings is + conditionally defined in a try/except in with_db_settings_base. + """ + # pyre-ignore[16]: DBSettings is conditionally defined in with_db_settings_base. + module_dbsettings = getattr(_wdb_module, "DBSettings", None) + self.assertIsNotNone( + module_dbsettings, + "with_db_settings_base.DBSettings is None -- guard removal regressed", + ) + self.assertIs(module_dbsettings, DBSettings) + + def test_sa_major_matches_buck_target(self) -> None: + """When the BUCK target sets EXPECTED_SA_MAJOR, assert the runtime + SQLAlchemy major matches. Makes :tests vs :tests_sa2 self-proving. + Skipped when EXPECTED_SA_MAJOR is unset (e.g., local one-off invocations). + """ + expected_major_str = os.environ.get("EXPECTED_SA_MAJOR") + if expected_major_str is None: + self.skipTest( + "EXPECTED_SA_MAJOR not set; only enforced under the dual-version " + "BUCK targets that pin SQLAlchemy via constraint_overrides" + ) + # pyre-ignore[16]: Module `sqlalchemy` has no attribute `__version__`. + actual_version = sqlalchemy.__version__ + actual_major = int(actual_version.split(".")[0]) + self.assertEqual( + actual_major, + int(expected_major_str), + f"BUCK target expected SQLAlchemy major {expected_major_str}, " + f"got {actual_version}", + ) diff --git a/ax/storage/sqa_store/with_db_settings_base.py b/ax/storage/sqa_store/with_db_settings_base.py index 5ad5acb6c03..c86589bfea4 100644 --- a/ax/storage/sqa_store/with_db_settings_base.py +++ b/ax/storage/sqa_store/with_db_settings_base.py @@ -6,7 +6,6 @@ # pyre-strict -import re import time from collections.abc import Sequence from logging import INFO, Logger @@ -16,11 +15,7 @@ from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun from ax.core.runner import Runner -from ax.exceptions.core import ( - IncompatibleDependencyVersion, - ObjectNotFoundError, - UnsupportedError, -) +from ax.exceptions.core import ObjectNotFoundError, UnsupportedError from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.utils.common.executils import retry_on_exception from ax.utils.common.logger import _round_floats_for_logging, get_logger @@ -37,18 +32,7 @@ from sqlalchemy import __version__ as sqa_version # pyre-fixme[16]: Module `sqlalchemy` has no attribute `__version__`. - sqa_major_version = int(none_throws(re.match(r"^\d*", sqa_version))[0]) - if sqa_major_version > 1: - msg = ( - "Ax currently requires a sqlalchemy version below 2.0. This will be " - "addressed in a future release. Disabling SQL storage in Ax for now, if " - "you would like to use SQL storage please install Ax with mysql extras " - "via `pip install ax-platform[mysql]`." - ) - - logger.warning(msg) - - raise IncompatibleDependencyVersion(msg) + logger.info(f"Ax SQL storage initialized with SQLAlchemy {sqa_version}") from ax.storage.sqa_store.db import init_engine_and_session_factory from ax.storage.sqa_store.decoder import Decoder @@ -78,7 +62,7 @@ # We retry on `OperationalError` if saving to DB. RETRY_EXCEPTION_TYPES = (OperationalError, StaleDataError) -except (ModuleNotFoundError, IncompatibleDependencyVersion, TypeError): +except (ModuleNotFoundError, TypeError): DBSettings = None TDBSettings = None Decoder = None