Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions ax/storage/sqa_store/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,10 +614,12 @@ def get_generation_strategy_sqa_reduced_state(
gr_sqa_class.parameter_constraints
),
defaultload(gs_sqa_class.generator_runs).lazyload(gr_sqa_class.metrics),
defaultload(gs_sqa_class.generator_runs).defer("model_kwargs"),
defaultload(gs_sqa_class.generator_runs).defer("bridge_kwargs"),
defaultload(gs_sqa_class.generator_runs).defer("model_state_after_gen"),
defaultload(gs_sqa_class.generator_runs).defer("gen_metadata"),
defaultload(gs_sqa_class.generator_runs).defer(gr_sqa_class.model_kwargs),
defaultload(gs_sqa_class.generator_runs).defer(gr_sqa_class.bridge_kwargs),
defaultload(gs_sqa_class.generator_runs).defer(
gr_sqa_class.model_state_after_gen
),
defaultload(gs_sqa_class.generator_runs).defer(gr_sqa_class.gen_metadata),
],
)

Expand Down
3 changes: 1 addition & 2 deletions ax/storage/sqa_store/reduced_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,5 @@ def get_query_options_to_defer_large_model_cols() -> list[strategy_options.Load]
when loading experiment and generation strategy in reduced state.
"""
return [
defaultload(SQATrial.generator_runs).defer(col.key)
for col in GR_LARGE_MODEL_ATTRS
defaultload(SQATrial.generator_runs).defer(col) for col in GR_LARGE_MODEL_ATTRS
]
47 changes: 47 additions & 0 deletions ax/storage/sqa_store/tests/test_with_db_settings_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@

# pyre-strict
import logging
import os
import random
import string
from unittest.mock import patch

import sqlalchemy
from ax.core.base_trial import BaseTrial
from ax.core.experiment import Experiment
from ax.core.trial import Trial
from ax.core.trial_status import TrialStatus
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.storage.sqa_store import with_db_settings_base as _wdb_module
from ax.storage.sqa_store.db import init_test_engine_and_session_factory
from ax.storage.sqa_store.load import (
_load_experiment,
Expand Down Expand Up @@ -396,3 +399,47 @@ def test_try_load_generation_strategy(self) -> None:
lg.output[0],
)
self.assertEqual(output, generation_strategy)


class TestSQLAlchemyDualVersionCompat(TestCase):
"""Self-proving checks that the dual-version SA 2.0 BUCK targets actually
resolved their constraint_overrides and that the SA 2.0 hard guard is gone.

Part of the SA 2.0 dual-version migration (T163607006).
"""

def test_module_level_dbsettings_is_defined(self) -> None:
"""The SA 2.0 hard guard previously set DBSettings = None at module level
when SA major > 1, breaking WithDBSettingsBase.__init__ type checks. Now
that the guard is removed, DBSettings must always resolve to the real
type when SQLAlchemy is importable. Uses getattr because DBSettings is
conditionally defined in a try/except in with_db_settings_base.
"""
# pyre-ignore[16]: DBSettings is conditionally defined in with_db_settings_base.
module_dbsettings = getattr(_wdb_module, "DBSettings", None)
self.assertIsNotNone(
module_dbsettings,
"with_db_settings_base.DBSettings is None -- guard removal regressed",
)
self.assertIs(module_dbsettings, DBSettings)

def test_sa_major_matches_buck_target(self) -> None:
"""When the BUCK target sets EXPECTED_SA_MAJOR, assert the runtime
SQLAlchemy major matches. Makes :tests vs :tests_sa2 self-proving.
Skipped when EXPECTED_SA_MAJOR is unset (e.g., local one-off invocations).
"""
expected_major_str = os.environ.get("EXPECTED_SA_MAJOR")
if expected_major_str is None:
self.skipTest(
"EXPECTED_SA_MAJOR not set; only enforced under the dual-version "
"BUCK targets that pin SQLAlchemy via constraint_overrides"
)
# pyre-ignore[16]: Module `sqlalchemy` has no attribute `__version__`.
actual_version = sqlalchemy.__version__
actual_major = int(actual_version.split(".")[0])
self.assertEqual(
actual_major,
int(expected_major_str),
f"BUCK target expected SQLAlchemy major {expected_major_str}, "
f"got {actual_version}",
)
22 changes: 3 additions & 19 deletions ax/storage/sqa_store/with_db_settings_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

# pyre-strict

import re
import time
from collections.abc import Sequence
from logging import INFO, Logger
Expand All @@ -16,11 +15,7 @@
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.runner import Runner
from ax.exceptions.core import (
IncompatibleDependencyVersion,
ObjectNotFoundError,
UnsupportedError,
)
from ax.exceptions.core import ObjectNotFoundError, UnsupportedError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.utils.common.executils import retry_on_exception
from ax.utils.common.logger import _round_floats_for_logging, get_logger
Expand All @@ -37,18 +32,7 @@
from sqlalchemy import __version__ as sqa_version

# pyre-fixme[16]: Module `sqlalchemy` has no attribute `__version__`.
sqa_major_version = int(none_throws(re.match(r"^\d*", sqa_version))[0])
if sqa_major_version > 1:
msg = (
"Ax currently requires a sqlalchemy version below 2.0. This will be "
"addressed in a future release. Disabling SQL storage in Ax for now, if "
"you would like to use SQL storage please install Ax with mysql extras "
"via `pip install ax-platform[mysql]`."
)

logger.warning(msg)

raise IncompatibleDependencyVersion(msg)
logger.info(f"Ax SQL storage initialized with SQLAlchemy {sqa_version}")

from ax.storage.sqa_store.db import init_engine_and_session_factory
from ax.storage.sqa_store.decoder import Decoder
Expand Down Expand Up @@ -78,7 +62,7 @@

# We retry on `OperationalError` if saving to DB.
RETRY_EXCEPTION_TYPES = (OperationalError, StaleDataError)
except (ModuleNotFoundError, IncompatibleDependencyVersion, TypeError):
except (ModuleNotFoundError, TypeError):
DBSettings = None
TDBSettings = None
Decoder = None
Expand Down