diff --git a/ax/storage/sqa_store/db.py b/ax/storage/sqa_store/db.py index 2b9c7149735..c7864a1acd6 100644 --- a/ax/storage/sqa_store/db.py +++ b/ax/storage/sqa_store/db.py @@ -42,6 +42,22 @@ T = TypeVar("T") +def _bound_engine(session_factory: scoped_session) -> Engine: + """Return the ``Engine`` bound to ``session_factory``, raising if not bound. + + SA 2.0 types ``scoped_session.bind`` as ``Connection | Engine | None``. In + this module we always bind a freshly-created ``Engine`` via ``sessionmaker``, + so the runtime value is always an ``Engine``. This narrows the type for both + pyre and runtime safety. + """ + bind = session_factory.bind + if not isinstance(bind, Engine): + raise ValueError( + f"SESSION_FACTORY must be bound to an Engine, got {type(bind).__name__}." + ) + return bind + + class SQABase: """Metaclass for SQLAlchemy classes corresponding to core Ax classes.""" @@ -164,8 +180,7 @@ def init_engine_and_session_factory( if SESSION_FACTORY is not None: if force_init: - # pyre-ignore[16]: SA 2.0 bind is Union; runtime Engine. - SESSION_FACTORY.bind.dispose() + _bound_engine(SESSION_FACTORY).dispose() else: return if url is not None: @@ -205,8 +220,7 @@ def init_test_engine_and_session_factory( if SESSION_FACTORY is not None: if force_init: - # pyre-ignore[16]: SA 2.0 bind is Union; runtime Engine. - SESSION_FACTORY.bind.dispose() + _bound_engine(SESSION_FACTORY).dispose() else: return engine = create_test_engine(path=tier_or_path, echo=echo) @@ -264,8 +278,7 @@ def get_engine() -> Engine: global SESSION_FACTORY if SESSION_FACTORY is None: raise ValueError("Engine must be initialized first.") - # pyre-ignore[7]: SA 2.0 bind is Union; runtime Engine. - return SESSION_FACTORY.bind + return _bound_engine(SESSION_FACTORY) @contextmanager @@ -333,6 +346,5 @@ def session_context( finally: # Restore the old session factory session_factory.close() - # pyre-ignore[16]: SA 2.0 bind is Union; runtime Engine. - session_factory.bind.dispose() + _bound_engine(session_factory).dispose() SESSION_FACTORY = old_session diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 2609cfac63e..d564c764765 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -194,10 +194,8 @@ def _auxiliary_experiments_by_purpose_from_experiment_sqa( ): continue aux_experiment = auxiliary_experiment_from_name( - # pyre-ignore[6]: SA 2.0 Column[T] vs plain T param. experiment_name=auxiliary_experiment_sqa.source_experiment.name, config=self.config, - # pyre-ignore[6]: SA 2.0 Column[T] vs plain T param. is_active=auxiliary_experiment_sqa.is_active, reduced_state=reduced_state, ) @@ -252,9 +250,7 @@ def _init_experiment_from_sqa( raise SQADecodeError("Experiment SearchSpace cannot be None.") status_quo = ( Arm( - # pyre-ignore[6]: SA 2.0 Column[T] vs plain T param. parameters=experiment_sqa.status_quo_parameters, - # pyre-ignore[6]: SA 2.0 Column[T] vs plain T param. name=experiment_sqa.status_quo_name, ) if experiment_sqa.status_quo_parameters is not None @@ -323,9 +319,7 @@ def _init_mt_experiment_from_sqa( raise SQADecodeError("Experiment SearchSpace cannot be None.") status_quo = ( Arm( - # pyre-ignore[6]: SA 2.0 Column[T] vs plain T param. parameters=experiment_sqa.status_quo_parameters, - # pyre-ignore[6]: SA 2.0 Column[T] vs plain T param. name=experiment_sqa.status_quo_name, ) if experiment_sqa.status_quo_parameters is not None @@ -333,11 +327,13 @@ def _init_mt_experiment_from_sqa( ) default_trial_type = none_throws(experiment_sqa.default_trial_type) + # pyre-ignore[9]: SA 2.0 Column[Optional[str]] keys; runtime str. trial_type_to_runner: dict[str, Runner | None] = { none_throws(sqa_runner.trial_type): self.runner_from_sqa(sqa_runner) for sqa_runner in experiment_sqa.runners } if len(trial_type_to_runner) == 0: + # pyre-ignore[9]: SA 2.0 Column[Optional[str]] keys; runtime str. trial_type_to_runner = {default_trial_type: None} trial_types_with_metrics = { metric.trial_type @@ -347,6 +343,7 @@ def _init_mt_experiment_from_sqa( # trial_type_to_runner is instantiated to map all trial types to None, # so the trial types are associated with the experiment. This is # important for adding metrics. + # pyre-ignore[6]: SA 2.0 Column[T] keys vs str keys. trial_type_to_runner.update(dict.fromkeys(trial_types_with_metrics)) experiment = MultiTypeExperiment( diff --git a/ax/storage/sqa_store/json.py b/ax/storage/sqa_store/json.py index 258c71cc69f..630ece26ff9 100644 --- a/ax/storage/sqa_store/json.py +++ b/ax/storage/sqa_store/json.py @@ -93,8 +93,15 @@ class JSONEncodedLongText(JSONEncodedObject): impl = Text(LONGTEXT_BYTES) +# `Mutable*.as_mutable()` returns a `TypeEngine` subclass per SA 2.0 stubs. +# Cannot annotate as `TypeEngine[Any]` because SA 1.4's `TypeEngine` is not a +# Generic class (`type 'TypeEngine' is not subscriptable` at runtime under 1.4). +# Keep `TypeDecorator` and suppress the SA 2.0 type-stub mismatch on each line. # pyre-ignore[9]: SA 2.0 typed as_mutable returns TypeEngine; runtime TypeDecorator. JSONEncodedList: TypeDecorator = MutableList.as_mutable(JSONEncodedObject) +# pyre-ignore[9]: SA 2.0 typed as_mutable returns TypeEngine; runtime TypeDecorator. JSONEncodedDict: TypeDecorator = MutableDict.as_mutable(JSONEncodedObject) +# pyre-ignore[9]: SA 2.0 typed as_mutable returns TypeEngine; runtime TypeDecorator. JSONEncodedTextDict: TypeDecorator = MutableDict.as_mutable(JSONEncodedText) +# pyre-ignore[9]: SA 2.0 typed as_mutable returns TypeEngine; runtime TypeDecorator. JSONEncodedLongTextDict: TypeDecorator = MutableDict.as_mutable(JSONEncodedLongText) diff --git a/ax/storage/sqa_store/load.py b/ax/storage/sqa_store/load.py index 87b9ca808a5..b5cd204cb0d 100644 --- a/ax/storage/sqa_store/load.py +++ b/ax/storage/sqa_store/load.py @@ -714,8 +714,8 @@ def load_analysis_cards_by_experiment_name( analysis_card_sqa_class.children ) - exp_sqa_class: SQAExperiment = cast( - SQAExperiment, decoder.config.class_to_sqa_class[Experiment] + exp_sqa_class = cast( + type[SQAExperiment], decoder.config.class_to_sqa_class[Experiment] ) with session_scope() as session: diff --git a/ax/storage/sqa_store/reduced_state.py b/ax/storage/sqa_store/reduced_state.py index 208aa5998c5..727b202c41f 100644 --- a/ax/storage/sqa_store/reduced_state.py +++ b/ax/storage/sqa_store/reduced_state.py @@ -5,6 +5,11 @@ # LICENSE file in the root directory of this source tree. # pyre-strict +# pyre-ignore-all-errors[24] +# +# SA 2.0 requires a type param on InstrumentedAttribute, but SA 1.4 +# InstrumentedAttribute is not subscriptable at runtime (`type is not a generic +# class`), so we keep the bare form to preserve dual-version compatibility. from ax.storage.sqa_store.sqa_classes import SQAGeneratorRun, SQATrial @@ -12,9 +17,6 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute -# pyre-fixme[9]: `GR_LARGE_MODEL_ATTRS` is declared as `List[InstrumentedAttribute]` -# but SQLAlchemy class attributes are typed as `Column` in stubs; they are -# `InstrumentedAttribute` instances at runtime. GR_LARGE_MODEL_ATTRS: list[InstrumentedAttribute] = [ SQAGeneratorRun.model_kwargs, SQAGeneratorRun.bridge_kwargs, diff --git a/ax/storage/sqa_store/sqa_classes.py b/ax/storage/sqa_store/sqa_classes.py index ee656e8483a..f30999ce2eb 100644 --- a/ax/storage/sqa_store/sqa_classes.py +++ b/ax/storage/sqa_store/sqa_classes.py @@ -5,6 +5,21 @@ # LICENSE file in the root directory of this source tree. # pyre-strict +# pyre-ignore-all-errors[8] +# +# `Mapped[T] = Column(...)` is the SA 1.4-compatible bridging form. SA 2.0 stubs +# see `Column[T]` on the RHS as incompatible with `Mapped[T]` LHS, hence the +# file-level [8] suppression. Runtime works under both SA versions, and pyre +# still resolves `obj.attr` access correctly via the Mapped[T] descriptor at +# callsites. +# +# **CONTRACT FOR FUTURE EDITS**: when adding a column, the `Mapped[T]` annotation +# MUST match the `Column(...)` runtime nullability: +# - `Mapped[T]` → `Column(..., nullable=False)` (or `primary_key=True`) +# - `Mapped[T | None]` → `Column(...)` (default) or `Column(..., nullable=True)` +# Mapped[T] no longer drives nullability automatically here (unlike SA 2.0's +# `mapped_column`), so a mismatched declaration silently creates a column with +# the WRONG nullability and the annotation will lie. Audit explicitly. from __future__ import annotations @@ -51,6 +66,16 @@ ) from sqlalchemy.orm import backref, relationship +# Mapped is SA 2.0-only. Under SA 2.0, it MUST be importable at runtime because +# the declarative metaclass evaluates class annotations (via typing.get_type_hints +# semantics) at class-definition time. Under SA 1.4 it doesn't exist — but SA 1.4 +# also doesn't introspect mapped annotations, so silently skipping the import is +# safe (annotations stay as strings due to `from __future__ import annotations`). +try: + from sqlalchemy.orm import Mapped +except ImportError: + pass + ONLY_ONE_FIELDS = ["experiment_id", "generator_run_id"] @@ -60,80 +85,80 @@ class SQAParameter(Base): __tablename__: str = "parameter_v2" - domain_type: Column[DomainType] = Column(IntEnum(DomainType), nullable=False) - experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) - id: Column[int] = Column(Integer, primary_key=True) - generator_run_id: Column[int | None] = Column( + domain_type: Mapped[DomainType] = Column(IntEnum(DomainType), nullable=False) + experiment_id: Mapped[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) + id: Mapped[int] = Column(Integer, primary_key=True) + generator_run_id: Mapped[int | None] = Column( Integer, ForeignKey("generator_run_v2.id") ) - name: Column[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) - parameter_type: Column[ParameterType] = Column( + name: Mapped[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) + parameter_type: Mapped[ParameterType] = Column( IntEnum(ParameterType), nullable=False ) - is_fidelity: Column[bool | None] = Column(Boolean) - target_value: Column[TParamValue | None] = Column(JSONEncodedObject) - backfill_value: Column[TParamValue | None] = Column(JSONEncodedObject) - default_value: Column[TParamValue | None] = Column(JSONEncodedObject) + is_fidelity: Mapped[bool | None] = Column(Boolean) + target_value: Mapped[TParamValue | None] = Column(JSONEncodedObject) + backfill_value: Mapped[TParamValue | None] = Column(JSONEncodedObject) + default_value: Mapped[TParamValue | None] = Column(JSONEncodedObject) # Attributes for Range Parameters - digits: Column[int | None] = Column(Integer) - log_scale: Column[bool | None] = Column(Boolean) - lower: Column[Decimal | None] = Column(Float) - upper: Column[Decimal | None] = Column(Float) + digits: Mapped[int | None] = Column(Integer) + log_scale: Mapped[bool | None] = Column(Boolean) + lower: Mapped[Decimal | None] = Column(Float) + upper: Mapped[Decimal | None] = Column(Float) # Attributes for Choice Parameters - choice_values: Column[list[TParamValue] | None] = Column(JSONEncodedList) - is_ordered: Column[bool | None] = Column(Boolean) - is_task: Column[bool | None] = Column(Boolean) - dependents: Column[dict[TParamValue, list[str]] | None] = Column(JSONEncodedObject) + choice_values: Mapped[list[TParamValue] | None] = Column(JSONEncodedList) + is_ordered: Mapped[bool | None] = Column(Boolean) + is_task: Mapped[bool | None] = Column(Boolean) + dependents: Mapped[dict[TParamValue, list[str]] | None] = Column(JSONEncodedObject) # Attributes for Fixed Parameters - fixed_value: Column[TParamValue | None] = Column(JSONEncodedObject) + fixed_value: Mapped[TParamValue | None] = Column(JSONEncodedObject) # Attribute for Derived Parameters - expression_str: Column[str | None] = Column(String(LONGTEXT_BYTES)) + expression_str: Mapped[str | None] = Column(String(LONGTEXT_BYTES)) class SQAParameterConstraint(Base): __tablename__: str = "parameter_constraint_v2" - bound: Column[Decimal] = Column(Float, nullable=False) - constraint_dict: Column[dict[str, float]] = Column(JSONEncodedDict, nullable=False) - experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) - id: Column[int] = Column(Integer, primary_key=True) - generator_run_id: Column[int | None] = Column( + bound: Mapped[Decimal] = Column(Float, nullable=False) + constraint_dict: Mapped[dict[str, float]] = Column(JSONEncodedDict, nullable=False) + experiment_id: Mapped[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) + id: Mapped[int] = Column(Integer, primary_key=True) + generator_run_id: Mapped[int | None] = Column( Integer, ForeignKey("generator_run_v2.id") ) - type: Column[IntEnum] = Column(IntEnum(ParameterConstraintType), nullable=False) + type: Mapped[IntEnum] = Column(IntEnum(ParameterConstraintType), nullable=False) class SQAMetric(Base): __tablename__: str = "metric_v2" - experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) - generator_run_id: Column[int | None] = Column( + experiment_id: Mapped[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) + generator_run_id: Mapped[int | None] = Column( Integer, ForeignKey("generator_run_v2.id") ) - id: Column[int] = Column(Integer, primary_key=True) - lower_is_better: Column[bool | None] = Column(Boolean) - intent: Column[MetricIntent] = Column(StringEnum(MetricIntent), nullable=False) - metric_type: Column[int] = Column(Integer, nullable=False) - name: Column[str] = Column(String(LONG_STRING_FIELD_LENGTH), nullable=False) - properties: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict, default={}) - signature: Column[str] = Column(String(LONG_STRING_FIELD_LENGTH), nullable=False) + id: Mapped[int] = Column(Integer, primary_key=True) + lower_is_better: Mapped[bool | None] = Column(Boolean) + intent: Mapped[MetricIntent] = Column(StringEnum(MetricIntent), nullable=False) + metric_type: Mapped[int] = Column(Integer, nullable=False) + name: Mapped[str] = Column(String(LONG_STRING_FIELD_LENGTH), nullable=False) + properties: Mapped[dict[str, Any] | None] = Column(JSONEncodedTextDict, default={}) + signature: Mapped[str] = Column(String(LONG_STRING_FIELD_LENGTH), nullable=False) # Attributes for Objectives - minimize: Column[bool | None] = Column(Boolean) + minimize: Mapped[bool | None] = Column(Boolean) # Attributes for Outcome Constraints - op: Column[ComparisonOp | None] = Column(IntEnum(ComparisonOp)) - bound: Column[Decimal | None] = Column(Float) - relative: Column[bool | None] = Column(Boolean) + op: Mapped[ComparisonOp | None] = Column(IntEnum(ComparisonOp)) + bound: Mapped[Decimal | None] = Column(Float) + relative: Mapped[bool | None] = Column(Boolean) # Multi-type Experiment attributes - trial_type: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - canonical_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - scalarized_objective_id: Column[int | None] = Column( + trial_type: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + canonical_name: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + scalarized_objective_id: Mapped[int | None] = Column( Integer, ForeignKey("metric_v2.id") ) @@ -141,7 +166,7 @@ class SQAMetric(Base): # of Multi/Scalarized Objective contains all children of the parent metric # join_depth argument: used for loading self-referential relationships # https://docs.sqlalchemy.org/en/13/orm/self_referential.html#configuring-self-referential-eager-loading - scalarized_objective_children_metrics: list[SQAMetric] = relationship( + scalarized_objective_children_metrics: Mapped[list[SQAMetric]] = relationship( "SQAMetric", cascade="all, delete-orphan", lazy=True, @@ -149,92 +174,94 @@ class SQAMetric(Base): ) # Attribute only defined for the children of Scalarized Objective - scalarized_objective_weight: Column[Decimal | None] = Column(Float) - scalarized_outcome_constraint_id: Column[int | None] = Column( + scalarized_objective_weight: Mapped[Decimal | None] = Column(Float) + scalarized_outcome_constraint_id: Mapped[int | None] = Column( Integer, ForeignKey("metric_v2.id") ) - scalarized_outcome_constraint_children_metrics: list[SQAMetric] = relationship( - "SQAMetric", - cascade="all, delete-orphan", - lazy=True, - foreign_keys=[scalarized_outcome_constraint_id], + scalarized_outcome_constraint_children_metrics: Mapped[list[SQAMetric]] = ( + relationship( + "SQAMetric", + cascade="all, delete-orphan", + lazy=True, + foreign_keys=[scalarized_outcome_constraint_id], + ) ) - scalarized_outcome_constraint_weight: Column[Decimal | None] = Column(Float) + scalarized_outcome_constraint_weight: Mapped[Decimal | None] = Column(Float) class SQAArm(Base): __tablename__: str = "arm_v2" - generator_run_id: Column[int] = Column( + generator_run_id: Mapped[int] = Column( Integer, ForeignKey("generator_run_v2.id"), nullable=False ) - id: Column[int] = Column(Integer, primary_key=True) - name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - parameters: Column[TParameterization] = Column(JSONEncodedTextDict, nullable=False) - weight: Column[Decimal] = Column(Float, nullable=False, default=1.0) + id: Mapped[int] = Column(Integer, primary_key=True) + name: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + parameters: Mapped[TParameterization] = Column(JSONEncodedTextDict, nullable=False) + weight: Mapped[Decimal] = Column(Float, nullable=False, default=1.0) class SQAAbandonedArm(Base): __tablename__: str = "abandoned_arm_v2" - abandoned_reason: Column[str | None] = Column(String(LONG_STRING_FIELD_LENGTH)) - id: Column[int] = Column(Integer, primary_key=True) - name: Column[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) - time_abandoned: Column[datetime] = Column( + abandoned_reason: Mapped[str | None] = Column(String(LONG_STRING_FIELD_LENGTH)) + id: Mapped[int] = Column(Integer, primary_key=True) + name: Mapped[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) + time_abandoned: Mapped[datetime] = Column( IntTimestamp, nullable=False, default=datetime.now ) - trial_id: Column[int] = Column(Integer, ForeignKey("trial_v2.id"), nullable=False) + trial_id: Mapped[int] = Column(Integer, ForeignKey("trial_v2.id"), nullable=False) class SQAGeneratorRun(Base): __tablename__: str = "generator_run_v2" - best_arm_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - best_arm_parameters: Column[TParameterization | None] = Column(JSONEncodedTextDict) - best_arm_predictions: Column[TModelPredictArm | None] = Column(JSONEncodedList) - generator_run_type: Column[int | None] = Column(Integer) - id: Column[int] = Column(Integer, primary_key=True) - index: Column[int | None] = Column(Integer) - model_predictions: Column[TModelPredict | None] = Column(JSONEncodedList) - time_created: Column[datetime] = Column( + best_arm_name: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + best_arm_parameters: Mapped[TParameterization | None] = Column(JSONEncodedTextDict) + best_arm_predictions: Mapped[TModelPredictArm | None] = Column(JSONEncodedList) + generator_run_type: Mapped[int | None] = Column(Integer) + id: Mapped[int] = Column(Integer, primary_key=True) + index: Mapped[int | None] = Column(Integer) + model_predictions: Mapped[TModelPredict | None] = Column(JSONEncodedList) + time_created: Mapped[datetime] = Column( IntTimestamp, nullable=False, default=datetime.now ) - trial_id: Column[int | None] = Column(Integer, ForeignKey("trial_v2.id")) - weight: Column[Decimal | None] = Column(Float) - fit_time: Column[Decimal | None] = Column(Float) - gen_time: Column[Decimal | None] = Column(Float) - model_key: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - model_kwargs: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict) - bridge_kwargs: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict) - gen_metadata: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict) - model_state_after_gen: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict) - generation_strategy_id: Column[int | None] = Column( + trial_id: Mapped[int | None] = Column(Integer, ForeignKey("trial_v2.id")) + weight: Mapped[Decimal | None] = Column(Float) + fit_time: Mapped[Decimal | None] = Column(Float) + gen_time: Mapped[Decimal | None] = Column(Float) + model_key: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + model_kwargs: Mapped[dict[str, Any] | None] = Column(JSONEncodedTextDict) + bridge_kwargs: Mapped[dict[str, Any] | None] = Column(JSONEncodedTextDict) + gen_metadata: Mapped[dict[str, Any] | None] = Column(JSONEncodedTextDict) + model_state_after_gen: Mapped[dict[str, Any] | None] = Column(JSONEncodedTextDict) + generation_strategy_id: Mapped[int | None] = Column( Integer, ForeignKey("generation_strategy.id") ) - candidate_metadata_by_arm_signature: Column[dict[str, Any] | None] = Column( + candidate_metadata_by_arm_signature: Mapped[dict[str, Any] | None] = Column( JSONEncodedTextDict ) - generation_node_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - suggested_experiment_status: Column[ExperimentStatus | None] = Column( + generation_node_name: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + suggested_experiment_status: Mapped[ExperimentStatus | None] = Column( IntEnum(ExperimentStatus), nullable=True ) # relationships # Use selectin loading for collections to prevent idle timeout errors # (https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#selectin-eager-loading) - arms: list[SQAArm] = relationship( + arms: Mapped[list[SQAArm]] = relationship( "SQAArm", cascade="all, delete-orphan", lazy="selectin", order_by=lambda: SQAArm.id, ) - metrics: list[SQAMetric] = relationship( + metrics: Mapped[list[SQAMetric]] = relationship( "SQAMetric", cascade="all, delete-orphan", lazy="selectin" ) - parameters: list[SQAParameter] = relationship( + parameters: Mapped[list[SQAParameter]] = relationship( "SQAParameter", cascade="all, delete-orphan", lazy="selectin" ) - parameter_constraints: list[SQAParameterConstraint] = relationship( + parameter_constraints: Mapped[list[SQAParameterConstraint]] = relationship( "SQAParameterConstraint", cascade="all, delete-orphan", lazy="selectin" ) @@ -242,31 +269,31 @@ class SQAGeneratorRun(Base): class SQARunner(Base): __tablename__: str = "runner" - id: Column[int] = Column(Integer, primary_key=True) - experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) - properties: Column[dict[str, Any] | None] = Column( + id: Mapped[int] = Column(Integer, primary_key=True) + experiment_id: Mapped[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) + properties: Mapped[dict[str, Any] | None] = Column( JSONEncodedLongTextDict, default={} ) - runner_type: Column[int] = Column(Integer, nullable=False) - trial_id: Column[int | None] = Column(Integer, ForeignKey("trial_v2.id")) + runner_type: Mapped[int] = Column(Integer, nullable=False) + trial_id: Mapped[int | None] = Column(Integer, ForeignKey("trial_v2.id")) # Multi-type Experiment attributes - trial_type: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + trial_type: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) class SQAData(Base): __tablename__: str = "data_v2" - id: Column[int] = Column(Integer, primary_key=True) - data_json: Column[str] = Column(Text(LONGTEXT_BYTES), nullable=False) - description: Column[str | None] = Column(String(LONG_STRING_FIELD_LENGTH)) - experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) - time_created: Column[int] = Column(BigInteger, nullable=False) - trial_index: Column[int | None] = Column(Integer) - generation_strategy_id: Column[int | None] = Column( + id: Mapped[int] = Column(Integer, primary_key=True) + data_json: Mapped[str] = Column(Text(LONGTEXT_BYTES), nullable=False) + description: Mapped[str | None] = Column(String(LONG_STRING_FIELD_LENGTH)) + experiment_id: Mapped[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) + time_created: Mapped[int] = Column(BigInteger, nullable=False) + trial_index: Mapped[int | None] = Column(Integer) + generation_strategy_id: Mapped[int | None] = Column( Integer, ForeignKey("generation_strategy.id") ) - structure_metadata_json: Column[str | None] = Column( + structure_metadata_json: Mapped[str | None] = Column( Text(LONGTEXT_BYTES), nullable=True ) @@ -274,17 +301,17 @@ class SQAData(Base): class SQAGenerationStrategy(Base): __tablename__: str = "generation_strategy" - id: Column[int] = Column(Integer, primary_key=True) - name: Column[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) - steps: Column[list[dict[str, Any]]] = Column(JSONEncodedList, nullable=False) - curr_index: Column[int | None] = Column(Integer, nullable=True) - experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) - nodes: Column[list[dict[str, Any]]] = Column(JSONEncodedList, nullable=True) - curr_node_name: Column[str | None] = Column( + id: Mapped[int] = Column(Integer, primary_key=True) + name: Mapped[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) + steps: Mapped[list[dict[str, Any]]] = Column(JSONEncodedList, nullable=False) + curr_index: Mapped[int | None] = Column(Integer, nullable=True) + experiment_id: Mapped[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) + nodes: Mapped[list[dict[str, Any]] | None] = Column(JSONEncodedList, nullable=True) + curr_node_name: Mapped[str | None] = Column( String(NAME_OR_TYPE_FIELD_LENGTH), nullable=True ) - generator_runs: list[SQAGeneratorRun] = relationship( + generator_runs: Mapped[list[SQAGeneratorRun]] = relationship( "SQAGeneratorRun", cascade="all, delete-orphan", lazy="selectin", @@ -295,29 +322,29 @@ class SQAGenerationStrategy(Base): class SQATrial(Base): __tablename__: str = "trial_v2" - abandoned_reason: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - failed_reason: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - deployed_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - experiment_id: Column[int] = Column( + abandoned_reason: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + failed_reason: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + deployed_name: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + experiment_id: Mapped[int] = Column( Integer, ForeignKey("experiment_v2.id"), nullable=False ) - id: Column[int] = Column(Integer, primary_key=True) - index: Column[int] = Column(Integer, index=True, nullable=False) - is_batch: Column[bool] = Column("is_batched", Boolean, nullable=False, default=True) - num_arms_created: Column[int] = Column(Integer, nullable=False, default=0) - ttl_seconds: Column[int | None] = Column(Integer) - run_metadata: Column[dict[str, Any] | None] = Column(JSONEncodedLongTextDict) - stop_metadata: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict) - status: Column[TrialStatus] = Column( + id: Mapped[int] = Column(Integer, primary_key=True) + index: Mapped[int] = Column(Integer, index=True, nullable=False) + is_batch: Mapped[bool] = Column("is_batched", Boolean, nullable=False, default=True) + num_arms_created: Mapped[int] = Column(Integer, nullable=False, default=0) + ttl_seconds: Mapped[int | None] = Column(Integer) + run_metadata: Mapped[dict[str, Any] | None] = Column(JSONEncodedLongTextDict) + stop_metadata: Mapped[dict[str, Any] | None] = Column(JSONEncodedTextDict) + status: Mapped[TrialStatus] = Column( IntEnum(TrialStatus), nullable=False, default=TrialStatus.CANDIDATE ) - status_quo_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - time_completed: Column[datetime | None] = Column(IntTimestamp) - time_created: Column[datetime] = Column(IntTimestamp, nullable=False) - time_staged: Column[datetime | None] = Column(IntTimestamp) - time_run_started: Column[datetime | None] = Column(IntTimestamp) - trial_type: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - properties: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict, default={}) + status_quo_name: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + time_completed: Mapped[datetime | None] = Column(IntTimestamp) + time_created: Mapped[datetime] = Column(IntTimestamp, nullable=False) + time_staged: Mapped[datetime | None] = Column(IntTimestamp) + time_run_started: Mapped[datetime | None] = Column(IntTimestamp) + trial_type: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + properties: Mapped[dict[str, Any] | None] = Column(JSONEncodedTextDict, default={}) # relationships # Trials and experiments are mutable, so the children relationships need @@ -325,13 +352,13 @@ class SQATrial(Base): # a child, the old one will be deleted. # Use selectin loading for collections to prevent idle timeout errors # (https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#selectin-eager-loading) - abandoned_arms: list[SQAAbandonedArm] = relationship( + abandoned_arms: Mapped[list[SQAAbandonedArm]] = relationship( "SQAAbandonedArm", cascade="all, delete-orphan", lazy="selectin" ) - generator_runs: list[SQAGeneratorRun] = relationship( + generator_runs: Mapped[list[SQAGeneratorRun]] = relationship( "SQAGeneratorRun", cascade="all, delete-orphan", lazy="selectin" ) - runner: SQARunner = relationship( + runner: Mapped[SQARunner] = relationship( "SQARunner", uselist=False, cascade="all, delete-orphan", lazy=False ) @@ -339,24 +366,24 @@ class SQATrial(Base): class SQAAuxiliaryExperiment(Base): __tablename__: str = "auxiliary_experiments" - source_experiment_id: Column[int] = Column( + source_experiment_id: Mapped[int] = Column( Integer, ForeignKey("experiment_v2.id"), primary_key=True ) - target_experiment_id: Column[int] = Column( + target_experiment_id: Mapped[int] = Column( Integer, ForeignKey("experiment_v2.id"), primary_key=True ) - purpose: Column[str] = Column(String(LONG_STRING_FIELD_LENGTH), primary_key=True) - is_active: Column[bool] = Column(Boolean, nullable=False) - properties: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict) - time: Column[datetime] = Column(IntTimestamp, nullable=False, default=datetime.now) - source_experiment: SQAExperiment = relationship( + purpose: Mapped[str] = Column(String(LONG_STRING_FIELD_LENGTH), primary_key=True) + is_active: Mapped[bool] = Column(Boolean, nullable=False) + properties: Mapped[dict[str, Any] | None] = Column(JSONEncodedTextDict) + time: Mapped[datetime] = Column(IntTimestamp, nullable=False, default=datetime.now) + source_experiment: Mapped[SQAExperiment] = relationship( "SQAExperiment", foreign_keys=[source_experiment_id], lazy="selectin", viewonly=True, innerjoin=True, ) - target_experiment: SQAExperiment = relationship( + target_experiment: Mapped[SQAExperiment] = relationship( "SQAExperiment", foreign_keys=[target_experiment_id], lazy="selectin", @@ -368,23 +395,25 @@ class SQAAuxiliaryExperiment(Base): class SQAExperiment(Base): __tablename__: str = "experiment_v2" - description: Column[str | None] = Column(String(LONG_STRING_FIELD_LENGTH)) - experiment_type: Column[int | None] = Column(Integer) - id: Column[int] = Column(Integer, primary_key=True) - is_test: Column[bool] = Column(Boolean, nullable=False, default=False) - name: Column[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) - properties: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict, default={}) - status_quo_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - status_quo_parameters: Column[TParameterization | None] = Column( + description: Mapped[str | None] = Column(String(LONG_STRING_FIELD_LENGTH)) + experiment_type: Mapped[int | None] = Column(Integer) + id: Mapped[int] = Column(Integer, primary_key=True) + is_test: Mapped[bool] = Column(Boolean, nullable=False, default=False) + name: Mapped[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) + properties: Mapped[dict[str, Any] | None] = Column(JSONEncodedTextDict, default={}) + status_quo_name: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + status_quo_parameters: Mapped[TParameterization | None] = Column( JSONEncodedTextDict ) - time_created: Column[datetime] = Column(IntTimestamp, nullable=False) - status: Column[ExperimentStatus | None] = Column( + time_created: Mapped[datetime] = Column(IntTimestamp, nullable=False) + status: Mapped[ExperimentStatus | None] = Column( IntEnum(ExperimentStatus), nullable=True ) - default_trial_type: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - default_data_type: Column[DataType] = Column(IntEnum(DataType), nullable=True) - auxiliary_experiments_by_purpose: Column[dict[str, list[dict[str, Any]]] | None] = ( + default_trial_type: Mapped[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + default_data_type: Mapped[DataType | None] = Column( + IntEnum(DataType), nullable=True + ) + auxiliary_experiments_by_purpose: Mapped[dict[str, list[dict[str, Any]]] | None] = ( Column(JSONEncodedTextDict, nullable=True, default={}) ) @@ -394,38 +423,38 @@ class SQAExperiment(Base): # a child, the old one will be deleted. # Use selectin loading for collections to prevent idle timeout errors # (https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#selectin-eager-loading) - data: list[SQAData] = relationship( + data: Mapped[list[SQAData]] = relationship( "SQAData", cascade="all, delete-orphan", lazy="selectin" ) - metrics: list[SQAMetric] = relationship( + metrics: Mapped[list[SQAMetric]] = relationship( "SQAMetric", cascade="all, delete-orphan", lazy="selectin" ) - parameters: list[SQAParameter] = relationship( + parameters: Mapped[list[SQAParameter]] = relationship( "SQAParameter", cascade="all, delete-orphan", lazy="selectin" ) - parameter_constraints: list[SQAParameterConstraint] = relationship( + parameter_constraints: Mapped[list[SQAParameterConstraint]] = relationship( "SQAParameterConstraint", cascade="all, delete-orphan", lazy="selectin" ) - runners: list[SQARunner] = relationship( + runners: Mapped[list[SQARunner]] = relationship( "SQARunner", cascade="all, delete-orphan", lazy=False ) - trials: list[SQATrial] = relationship( + trials: Mapped[list[SQATrial]] = relationship( "SQATrial", cascade="all, delete-orphan", lazy="selectin" ) - generation_strategy: SQAGenerationStrategy | None = relationship( + generation_strategy: Mapped[SQAGenerationStrategy | None] = relationship( "SQAGenerationStrategy", backref=backref("experiment", lazy=True), uselist=False, lazy=True, ) - auxiliary_experiments: list[SQAAuxiliaryExperiment] = relationship( + auxiliary_experiments: Mapped[list[SQAAuxiliaryExperiment]] = relationship( "SQAAuxiliaryExperiment", cascade="all, delete-orphan", lazy="selectin", foreign_keys=[SQAAuxiliaryExperiment.target_experiment_id], ) - analysis_cards: list[SQAAnalysisCard] = relationship( + analysis_cards: Mapped[list[SQAAnalysisCard]] = relationship( "SQAAnalysisCard", cascade="all, delete-orphan", lazy="selectin" ) @@ -433,35 +462,35 @@ class SQAExperiment(Base): class SQAAnalysisCard(Base): __tablename__: str = "analysis_card_v2" - id: Column[int] = Column(Integer, primary_key=True) + id: Mapped[int] = Column(Integer, primary_key=True) - experiment_id: Column[int] = Column( + experiment_id: Mapped[int] = Column( Integer, ForeignKey("experiment_v2.id"), nullable=False ) - name: Column[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) - timestamp: Column[datetime] = Column(IntTimestamp, nullable=False) + name: Mapped[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) + timestamp: Mapped[datetime] = Column(IntTimestamp, nullable=False) - parent_id: Column[int | None] = Column( + parent_id: Mapped[int | None] = Column( Integer, ForeignKey("analysis_card_v2.id"), nullable=True, ) - order: Column[int | None] = Column(Integer, nullable=True) + order: Mapped[int | None] = Column(Integer, nullable=True) - title: Column[str | None] = Column(String(LONG_STRING_FIELD_LENGTH), nullable=True) - subtitle: Column[str | None] = Column(Text, nullable=True) - dataframe_json: Column[str | None] = Column(Text(LONGTEXT_BYTES), nullable=True) - blob: Column[str | None] = Column(Text(LONGTEXT_BYTES), nullable=True) - blob_annotation: Column[str | None] = Column( + title: Mapped[str | None] = Column(String(LONG_STRING_FIELD_LENGTH), nullable=True) + subtitle: Mapped[str | None] = Column(Text, nullable=True) + dataframe_json: Mapped[str | None] = Column(Text(LONGTEXT_BYTES), nullable=True) + blob: Mapped[str | None] = Column(Text(LONGTEXT_BYTES), nullable=True) + blob_annotation: Mapped[str | None] = Column( String(NAME_OR_TYPE_FIELD_LENGTH), nullable=True ) - parent: Any = relationship( + parent: Mapped[SQAAnalysisCard | None] = relationship( "SQAAnalysisCard", back_populates="children", remote_side=[id], lazy="selectin", ) - children: list[Any] = relationship( + children: Mapped[list[SQAAnalysisCard]] = relationship( "SQAAnalysisCard", cascade="all, delete-orphan", back_populates="parent", diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 4e1e08b7303..5f3c50540b4 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -282,12 +282,10 @@ def test_generator_run_type_validation(self) -> None: generator_run._generator_run_type = "STATUS_QUO" generator_run_sqa = self.encoder.generator_run_to_sqa(generator_run) - # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. generator_run_sqa.generator_run_type = 2 with self.assertRaises(SQADecodeError): self.decoder.generator_run_from_sqa(generator_run_sqa, False, False) - # pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine. generator_run_sqa.generator_run_type = 0 self.decoder.generator_run_from_sqa(generator_run_sqa, False, False) diff --git a/ax/storage/sqa_store/validation.py b/ax/storage/sqa_store/validation.py index 9346ccbe6c8..0c8bb4b411f 100644 --- a/ax/storage/sqa_store/validation.py +++ b/ax/storage/sqa_store/validation.py @@ -34,6 +34,8 @@ def listens_for_multiple( + # pyre-ignore[24]: SA 2.0 requires a type param; SA 1.4 InstrumentedAttribute + # is not subscriptable at runtime, so we keep the bare form. targets: list[InstrumentedAttribute], identifier: str, *args: Any,