From 5f02bbf6f038b3df6d06c4684e0c50e982d2c04b Mon Sep 17 00:00:00 2001 From: bjorkbjork Date: Sat, 2 May 2026 18:47:24 +1000 Subject: [PATCH 1/2] add opt-in autogenerate support for CHECK constraint detection Add a new comparator plugin that detects added and removed named CHECK constraints during autogenerate. Comparison is name-only by default; the dialect hook compare_check_constraint exists for future per-dialect expression comparison. The plugin is registered under alembic.ext.checkconstraint and must be explicitly opted into via autogenerate_plugins: context.configure( autogenerate_plugins=[ "alembic.autogenerate.*", "alembic.ext.checkconstraint", ] ) Unnamed constraints, type-bound constraints (Boolean/Enum), and dialects that do not support check constraint reflection are handled gracefully. Fixes: #508 --- alembic/autogenerate/compare/__init__.py | 4 + .../autogenerate/compare/check_constraints.py | 186 +++++ alembic/autogenerate/compare/util.py | 25 + alembic/autogenerate/render.py | 20 +- alembic/ddl/_autogen.py | 35 + alembic/ddl/impl.py | 12 +- alembic/testing/requirements.py | 4 + tests/test_autogen_check_constraints.py | 744 ++++++++++++++++++ 8 files changed, 1027 insertions(+), 3 deletions(-) create mode 100644 alembic/autogenerate/compare/check_constraints.py create mode 100644 tests/test_autogen_check_constraints.py diff --git a/alembic/autogenerate/compare/__init__.py b/alembic/autogenerate/compare/__init__.py index a49640cf..30ed40ba 100644 --- a/alembic/autogenerate/compare/__init__.py +++ b/alembic/autogenerate/compare/__init__.py @@ -3,6 +3,7 @@ import logging from typing import TYPE_CHECKING +from . import check_constraints from . import comments from . import constraints from . import schema @@ -60,3 +61,6 @@ def _produce_net_changes( server_defaults, "alembic.autogenerate.defaults" ) Plugin.setup_plugin_from_module(comments, "alembic.autogenerate.comments") +Plugin.setup_plugin_from_module( + check_constraints, "alembic.ext.checkconstraint" +) diff --git a/alembic/autogenerate/compare/check_constraints.py b/alembic/autogenerate/compare/check_constraints.py new file mode 100644 index 00000000..b8a8e951 --- /dev/null +++ b/alembic/autogenerate/compare/check_constraints.py @@ -0,0 +1,186 @@ +# mypy: allow-untyped-defs, allow-untyped-calls, allow-incomplete-defs + +from __future__ import annotations + +import logging +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import schema as sa_schema + +from .util import _InspectorConv +from ...operations import ops +from ...util import PriorityDispatchResult +from ...util import sqla_compat + +if TYPE_CHECKING: + from sqlalchemy.sql.elements import quoted_name + from sqlalchemy.sql.schema import CheckConstraint + from sqlalchemy.sql.schema import Table + + from ...autogenerate.api import AutogenContext + from ...ddl.impl import DefaultImpl + from ...operations.ops import ModifyTableOps + from ...runtime.plugins import Plugin + + +log = logging.getLogger(__name__) + + +def _make_check_constraint( + impl: DefaultImpl, + params: dict, + conn_table: Table, +) -> CheckConstraint: + const = sa_schema.CheckConstraint( + params["sqltext"], + name=params["name"], + **impl.adjust_reflected_dialect_options(params, "check_constraint"), + ) + conn_table.append_constraint(const) + return const + + +def _compare_check_constraints( + autogen_context: AutogenContext, + modify_table_ops: ModifyTableOps, + schema: Optional[str], + tname: Union[quoted_name, str], + conn_table: Optional[Table], + metadata_table: Optional[Table], +) -> PriorityDispatchResult: + if conn_table is None or metadata_table is None: + return PriorityDispatchResult.CONTINUE + + inspector = autogen_context.inspector + impl = autogen_context.migration_context.impl + + metadata_ck_constraints = { + ck + for ck in metadata_table.constraints + if isinstance(ck, sa_schema.CheckConstraint) + and not sqla_compat._is_type_bound(ck) + } + + try: + conn_ck_list = _InspectorConv(inspector).get_check_constraints( + tname, schema=schema + ) + except NotImplementedError: + return PriorityDispatchResult.CONTINUE + + conn_ck_list = [ + ck + for ck in conn_ck_list + if ck.get("name") is not None + and autogen_context.run_name_filters( + ck["name"], + "check_constraint", + {"table_name": tname, "schema_name": schema}, + ) + ] + + conn_ck_objs = { + _make_check_constraint(impl, ck_def, conn_table) + for ck_def in conn_ck_list + } + + metadata_ck_sig = { + impl._create_metadata_constraint_sig(ck) + for ck in metadata_ck_constraints + if sqla_compat._constraint_is_named(ck, autogen_context.dialect) + } + + conn_ck_sig = { + impl._create_reflected_constraint_sig(ck) for ck in conn_ck_objs + } + + metadata_ck_by_name = {c.name: c for c in metadata_ck_sig if c.name} + conn_ck_by_name = {c.name: c for c in conn_ck_sig if c.name} + + for removed_name in sorted( + set(conn_ck_by_name).difference(metadata_ck_by_name) + ): + conn_obj = conn_ck_by_name[removed_name] + if autogen_context.run_object_filters( + conn_obj.const, + conn_obj.name, + "check_constraint", + True, + None, + ): + modify_table_ops.ops.append( + ops.DropConstraintOp.from_constraint(conn_obj.const) + ) + log.info( + "Detected removed check constraint %r on table %r", + conn_obj.name, + tname, + ) + + for existing_name in sorted( + set(metadata_ck_by_name).intersection(conn_ck_by_name) + ): + metadata_obj = metadata_ck_by_name[existing_name] + conn_obj = conn_ck_by_name[existing_name] + + comparison = metadata_obj.compare_to_reflected(conn_obj) + + if comparison.is_different: + if autogen_context.run_object_filters( + metadata_obj.const, + metadata_obj.name, + "check_constraint", + False, + conn_obj.const, + ): + log.info( + "Detected changed check constraint %r on table %r: %s", + existing_name, + tname, + comparison.message, + ) + modify_table_ops.ops.append( + ops.DropConstraintOp.from_constraint(conn_obj.const) + ) + modify_table_ops.ops.append( + ops.AddConstraintOp.from_constraint(metadata_obj.const) + ) + elif comparison.is_skip: + log.info( + "Cannot compare check constraint %r, " + "assuming equal and skipping. %s", + existing_name, + comparison.message, + ) + + for added_name in sorted( + set(metadata_ck_by_name).difference(conn_ck_by_name) + ): + metadata_obj = metadata_ck_by_name[added_name] + if autogen_context.run_object_filters( + metadata_obj.const, + metadata_obj.name, + "check_constraint", + False, + None, + ): + modify_table_ops.ops.append( + ops.AddConstraintOp.from_constraint(metadata_obj.const) + ) + log.info( + "Detected added check constraint %r on table %r", + metadata_obj.name, + tname, + ) + + return PriorityDispatchResult.CONTINUE + + +def setup(plugin: Plugin) -> None: + plugin.add_autogenerate_comparator( + _compare_check_constraints, + "table", + "checkconstraints", + ) diff --git a/alembic/autogenerate/compare/util.py b/alembic/autogenerate/compare/util.py index 41829c0e..dfec3685 100644 --- a/alembic/autogenerate/compare/util.py +++ b/alembic/autogenerate/compare/util.py @@ -15,6 +15,7 @@ if TYPE_CHECKING: from sqlalchemy import Table from sqlalchemy.engine import Inspector + from sqlalchemy.engine.interfaces import ReflectedCheckConstraint from sqlalchemy.engine.interfaces import ReflectedForeignKeyConstraint from sqlalchemy.engine.interfaces import ReflectedIndex from sqlalchemy.engine.interfaces import ReflectedUniqueConstraint @@ -78,6 +79,11 @@ def get_foreign_keys( ) -> list[ReflectedForeignKeyConstraint]: raise NotImplementedError() + def get_check_constraints( + self, tname: str, schema: str | None + ) -> list[ReflectedCheckConstraint]: + raise NotImplementedError() + def reflect_table(self, table: Table) -> None: raise NotImplementedError() @@ -123,6 +129,13 @@ def get_foreign_keys( self.inspector.get_foreign_keys(tname, schema=schema) ) + def get_check_constraints( + self, tname: str, schema: str | None + ) -> list[ReflectedCheckConstraint]: + return self._apply_reflectinfo_conv( + self.inspector.get_check_constraints(tname, schema=schema) + ) + def reflect_table(self, table: Table) -> None: self.inspector.reflect_table(table, include_columns=None) @@ -252,6 +265,18 @@ def get_foreign_keys( apply_constraint_conv=True, ) + def get_check_constraints( + self, tname: str, schema: str | None + ) -> list[ReflectedCheckConstraint]: + return self._return_from_cache( + tname, + schema, + "alembic_check_constraints", + self.inspector.get_check_constraints, + apply_constraint_conv=True, + optional=False, + ) + def _apply_reflectinfo_conv(self, consts): if not consts: return consts diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py index 4cae1cf3..489314e3 100644 --- a/alembic/autogenerate/render.py +++ b/alembic/autogenerate/render.py @@ -438,8 +438,24 @@ def _add_pk_constraint(constraint, autogen_context): @renderers.dispatch_for(ops.CreateCheckConstraintOp) -def _add_check_constraint(constraint, autogen_context): - raise NotImplementedError() +def _add_check_constraint( + autogen_context: AutogenContext, op: ops.CreateCheckConstraintOp +) -> str: + constraint = op.to_constraint() + args = [repr(_render_gen_name(autogen_context, op.constraint_name))] + if not autogen_context._has_batch: + args.append(repr(_ident(op.table_name))) + args.append( + _render_potential_expr( + constraint.sqltext, autogen_context, wrap_in_element=False + ) + ) + if not autogen_context._has_batch and op.schema: + args.append("schema=%r" % _ident(op.schema)) + return "%(prefix)screate_check_constraint(%(args)s)" % { + "prefix": _alembic_autogenerate_prefix(autogen_context), + "args": ", ".join(args), + } @renderers.dispatch_for(ops.DropConstraintOp) diff --git a/alembic/ddl/_autogen.py b/alembic/ddl/_autogen.py index 74715b18..417e18a9 100644 --- a/alembic/ddl/_autogen.py +++ b/alembic/ddl/_autogen.py @@ -16,6 +16,7 @@ from typing import TypeVar from typing import Union +from sqlalchemy.sql.schema import CheckConstraint from sqlalchemy.sql.schema import Constraint from sqlalchemy.sql.schema import ForeignKeyConstraint from sqlalchemy.sql.schema import Index @@ -86,6 +87,7 @@ class _constraint_sig(Generic[_C]): _is_index: ClassVar[bool] = False _is_fk: ClassVar[bool] = False _is_uq: ClassVar[bool] = False + _is_ck: ClassVar[bool] = False _is_metadata: bool @@ -325,5 +327,38 @@ def is_uq_sig(sig: _constraint_sig) -> TypeGuard[_uq_constraint_sig]: return sig._is_uq +class _ck_constraint_sig(_constraint_sig[CheckConstraint]): + _is_ck = True + + @classmethod + def _register(cls) -> None: + _clsreg["check_constraint"] = cls + _clsreg["table_or_column_check_constraint"] = cls + _clsreg["column_check_constraint"] = cls + + def __init__( + self, + is_metadata: bool, + impl: DefaultImpl, + const: CheckConstraint, + ) -> None: + self._is_metadata = is_metadata + self.impl = impl + self.const = const + self.name = sqla_compat.constraint_name_or_none(const.name) + self._sig = (self.name,) + + def _compare_to_reflected( + self, other: _constraint_sig[_C] + ) -> ComparisonResult: + assert self._is_metadata + assert is_ck_sig(other) + return self.impl.compare_check_constraint(self.const, other.const) + + +def is_ck_sig(sig: _constraint_sig) -> TypeGuard[_ck_constraint_sig]: + return sig._is_ck + + def is_fk_sig(sig: _constraint_sig) -> TypeGuard[_fk_constraint_sig]: return sig._is_fk diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 964cd1f3..c77bd555 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -43,6 +43,7 @@ from sqlalchemy.engine import Connection from sqlalchemy.engine import Dialect from sqlalchemy.engine.cursor import CursorResult + from sqlalchemy.engine.interfaces import ReflectedCheckConstraint from sqlalchemy.engine.interfaces import ReflectedForeignKeyConstraint from sqlalchemy.engine.interfaces import ReflectedIndex from sqlalchemy.engine.interfaces import ReflectedPrimaryKeyConstraint @@ -51,6 +52,7 @@ from sqlalchemy.sql import ClauseElement from sqlalchemy.sql import Executable from sqlalchemy.sql.elements import quoted_name + from sqlalchemy.sql.schema import CheckConstraint from sqlalchemy.sql.schema import Constraint from sqlalchemy.sql.schema import ForeignKeyConstraint from sqlalchemy.sql.schema import Index @@ -64,7 +66,8 @@ from ..operations.batch import BatchOperationsImpl _ReflectedConstraint = ( - ReflectedForeignKeyConstraint + ReflectedCheckConstraint + | ReflectedForeignKeyConstraint | ReflectedPrimaryKeyConstraint | ReflectedIndex | ReflectedUniqueConstraint @@ -840,6 +843,13 @@ def compare_unique_constraint( else: return ComparisonResult.Equal() + def compare_check_constraint( + self, + metadata_constraint: CheckConstraint, + reflected_constraint: CheckConstraint, + ) -> ComparisonResult: + return ComparisonResult.Equal() + def _skip_functional_indexes(self, metadata_indexes, conn_indexes): conn_indexes_by_name = {c.name: c for c in conn_indexes} diff --git a/alembic/testing/requirements.py b/alembic/testing/requirements.py index 1b217c93..e087900a 100644 --- a/alembic/testing/requirements.py +++ b/alembic/testing/requirements.py @@ -65,6 +65,10 @@ def check_constraints_w_enforcement(self): return exclusions.open() + @property + def check_constraint_reflection(self): + return exclusions.open() + @property def reflects_pk_names(self): return exclusions.closed() diff --git a/tests/test_autogen_check_constraints.py b/tests/test_autogen_check_constraints.py new file mode 100644 index 00000000..7436a426 --- /dev/null +++ b/tests/test_autogen_check_constraints.py @@ -0,0 +1,744 @@ +from sqlalchemy import Boolean +from sqlalchemy import CheckConstraint +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import Table + +from alembic import autogenerate +from alembic.autogenerate import api +from alembic.migration import MigrationContext +from alembic.operations import ops +from alembic.testing import config +from alembic.testing import eq_ +from alembic.testing import TestBase +from alembic.testing import util +from alembic.testing.env import clear_staging_env +from alembic.testing.env import staging_env +from alembic.testing.suite._autogen_fixtures import AutogenFixtureTest + + +_ck_plugin_opts = { + "autogenerate_plugins": [ + "alembic.autogenerate.*", + "alembic.ext.checkconstraint", + ] +} + + +class AutogenCheckConstraintTest(AutogenFixtureTest, TestBase): + __backend__ = True + __requires__ = ("check_constraint_reflection",) + + def test_add_check_constraint(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "t", + m1, + Column("x", Integer), + ) + + Table( + "t", + m2, + Column("x", Integer), + CheckConstraint("x > 0", name="ck_t_x_positive"), + ) + + diffs = self._fixture(m1, m2, opts=_ck_plugin_opts) + + eq_(len(diffs), 1) + eq_(diffs[0][0], "add_constraint") + eq_(diffs[0][1].name, "ck_t_x_positive") + + def test_remove_check_constraint(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "t", + m1, + Column("x", Integer), + CheckConstraint("x > 0", name="ck_t_x_positive"), + ) + + Table( + "t", + m2, + Column("x", Integer), + ) + + diffs = self._fixture(m1, m2, opts=_ck_plugin_opts) + + eq_(len(diffs), 1) + eq_(diffs[0][0], "remove_constraint") + eq_(diffs[0][1].name, "ck_t_x_positive") + + def test_same_name_different_expression_no_change(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "t", + m1, + Column("x", Integer), + CheckConstraint("x > 0", name="ck_t_x_positive"), + ) + + Table( + "t", + m2, + Column("x", Integer), + CheckConstraint("x > 5", name="ck_t_x_positive"), + ) + + diffs = self._fixture(m1, m2, opts=_ck_plugin_opts) + + eq_(diffs, []) + + def test_no_change_check_constraint(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "t", + m1, + Column("x", Integer), + CheckConstraint("x > 0", name="ck_t_x_positive"), + ) + + Table( + "t", + m2, + Column("x", Integer), + CheckConstraint("x > 0", name="ck_t_x_positive"), + ) + + diffs = self._fixture(m1, m2, opts=_ck_plugin_opts) + + eq_(diffs, []) + + def test_unnamed_check_constraint_in_metadata_ignored(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "t", + m1, + Column("x", Integer), + ) + + Table( + "t", + m2, + Column("x", Integer), + CheckConstraint("x > 0"), + ) + + diffs = self._fixture(m1, m2, opts=_ck_plugin_opts) + + eq_(diffs, []) + + def test_type_bound_boolean_not_detected(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "t", + m1, + Column("x", Integer), + ) + + Table( + "t", + m2, + Column("x", Integer), + Column("flag", Boolean(create_constraint=True)), + ) + + diffs = self._fixture(m1, m2, opts=_ck_plugin_opts) + + check_diffs = [ + d + for d in diffs + if d[0] in ("add_constraint", "remove_constraint") + and isinstance(d[1], CheckConstraint) + ] + eq_(check_diffs, []) + + def test_multiple_check_constraints(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "t", + m1, + Column("x", Integer), + Column("y", Integer), + CheckConstraint("x > 0", name="ck_x"), + ) + + Table( + "t", + m2, + Column("x", Integer), + Column("y", Integer), + CheckConstraint("x > 0", name="ck_x"), + CheckConstraint("y > 0", name="ck_y"), + ) + + diffs = self._fixture(m1, m2, opts=_ck_plugin_opts) + + eq_(len(diffs), 1) + eq_(diffs[0][0], "add_constraint") + eq_(diffs[0][1].name, "ck_y") + + def test_remove_one_of_multiple(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "t", + m1, + Column("x", Integer), + Column("y", Integer), + CheckConstraint("x > 0", name="ck_x"), + CheckConstraint("y > 0", name="ck_y"), + ) + + Table( + "t", + m2, + Column("x", Integer), + Column("y", Integer), + CheckConstraint("x > 0", name="ck_x"), + ) + + diffs = self._fixture(m1, m2, opts=_ck_plugin_opts) + + eq_(len(diffs), 1) + eq_(diffs[0][0], "remove_constraint") + eq_(diffs[0][1].name, "ck_y") + + def test_add_table_with_check_constraint_no_duplicate(self): + m1 = MetaData() + m2 = MetaData() + + Table("t", m1, Column("x", Integer)) + + Table("t", m2, Column("x", Integer)) + Table( + "new_table", + m2, + Column("x", Integer), + CheckConstraint("x > 0", name="ck_new_x"), + ) + + diffs = self._fixture(m1, m2, opts=_ck_plugin_opts) + + add_table = [d for d in diffs if d[0] == "add_table"] + eq_(len(add_table), 1) + eq_(add_table[0][1].name, "new_table") + + new_table = add_table[0][1] + ck_in_table = [ + c + for c in new_table.constraints + if isinstance(c, CheckConstraint) and c.name == "ck_new_x" + ] + eq_(len(ck_in_table), 1) + + add_ck = [ + d + for d in diffs + if d[0] == "add_constraint" and isinstance(d[1], CheckConstraint) + ] + eq_(add_ck, []) + + def test_drop_table_with_check_constraint_no_duplicate(self): + m1 = MetaData() + m2 = MetaData() + + Table("t", m1, Column("x", Integer)) + Table( + "old_table", + m1, + Column("x", Integer), + CheckConstraint("x > 0", name="ck_old_x"), + ) + + Table("t", m2, Column("x", Integer)) + + diffs = self._fixture(m1, m2, opts=_ck_plugin_opts) + + drop_table = [d for d in diffs if d[0] == "remove_table"] + eq_(len(drop_table), 1) + eq_(drop_table[0][1].name, "old_table") + + old_table = drop_table[0][1] + ck_in_table = [ + c + for c in old_table.constraints + if isinstance(c, CheckConstraint) and c.name == "ck_old_x" + ] + eq_(len(ck_in_table), 1) + + drop_ck = [ + d + for d in diffs + if d[0] == "remove_constraint" + and isinstance(d[1], CheckConstraint) + ] + eq_(drop_ck, []) + + +class AutogenCheckConstraintFilterTest(AutogenFixtureTest, TestBase): + __backend__ = True + __requires__ = ("check_constraint_reflection",) + + def test_include_name_excludes_reflected_check_constraint(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "t", + m1, + Column("x", Integer), + CheckConstraint("x > 0", name="ck_t_x_positive"), + ) + + Table( + "t", + m2, + Column("x", Integer), + ) + + def include_name(name, type_, parent_names): + if type_ == "check_constraint": + return False + return True + + diffs = self._fixture( + m1, + m2, + name_filters=include_name, + opts=_ck_plugin_opts, + ) + + check_diffs = [ + d + for d in diffs + if d[0] in ("add_constraint", "remove_constraint") + and isinstance(d[1], CheckConstraint) + ] + eq_(check_diffs, []) + + def test_include_object_excludes_add(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "t", + m1, + Column("x", Integer), + ) + + Table( + "t", + m2, + Column("x", Integer), + CheckConstraint("x > 0", name="ck_t_x_positive"), + ) + + def include_object(obj, name, type_, reflected, compare_to): + if type_ == "check_constraint": + return False + return True + + diffs = self._fixture( + m1, + m2, + object_filters=include_object, + opts=_ck_plugin_opts, + ) + + check_diffs = [ + d + for d in diffs + if d[0] in ("add_constraint", "remove_constraint") + and isinstance(d[1], CheckConstraint) + ] + eq_(check_diffs, []) + + def test_include_object_excludes_remove(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "t", + m1, + Column("x", Integer), + CheckConstraint("x > 0", name="ck_t_x_positive"), + ) + + Table( + "t", + m2, + Column("x", Integer), + ) + + def include_object(obj, name, type_, reflected, compare_to): + if type_ == "check_constraint": + return False + return True + + diffs = self._fixture( + m1, + m2, + object_filters=include_object, + opts=_ck_plugin_opts, + ) + + check_diffs = [ + d + for d in diffs + if d[0] in ("add_constraint", "remove_constraint") + and isinstance(d[1], CheckConstraint) + ] + eq_(check_diffs, []) + + def test_include_object_receives_correct_args_for_add(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "t", + m1, + Column("x", Integer), + ) + + Table( + "t", + m2, + Column("x", Integer), + CheckConstraint("x > 0", name="ck_t_x_positive"), + ) + + calls = [] + + def include_object(obj, name, type_, reflected, compare_to): + if type_ == "check_constraint": + calls.append((name, type_, reflected, compare_to)) + return True + + self._fixture( + m1, + m2, + object_filters=include_object, + opts=_ck_plugin_opts, + ) + + eq_(len(calls), 1) + eq_(calls[0][0], "ck_t_x_positive") + eq_(calls[0][1], "check_constraint") + eq_(calls[0][2], False) + eq_(calls[0][3], None) + + def test_include_object_receives_correct_args_for_remove(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "t", + m1, + Column("x", Integer), + CheckConstraint("x > 0", name="ck_t_x_positive"), + ) + + Table( + "t", + m2, + Column("x", Integer), + ) + + calls = [] + + def include_object(obj, name, type_, reflected, compare_to): + if type_ == "check_constraint": + calls.append((name, type_, reflected, compare_to)) + return True + + self._fixture( + m1, + m2, + object_filters=include_object, + opts=_ck_plugin_opts, + ) + + eq_(len(calls), 1) + eq_(calls[0][0], "ck_t_x_positive") + eq_(calls[0][1], "check_constraint") + eq_(calls[0][2], True) + eq_(calls[0][3], None) + + +class AutogenCheckConstraintNoReflectionTest(AutogenFixtureTest, TestBase): + __backend__ = True + + def setUp(self): + staging_env() + self.bind = eng = util.testing_engine() + + def unimpl(*arg, **kw): + raise NotImplementedError() + + eng.dialect.get_check_constraints = unimpl + + def test_no_reflection_graceful_skip_add(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "t", + m1, + Column("x", Integer), + ) + + Table( + "t", + m2, + Column("x", Integer), + CheckConstraint("x > 0", name="ck_t_x_positive"), + ) + + diffs = self._fixture(m1, m2, opts=_ck_plugin_opts) + + check_diffs = [ + d + for d in diffs + if d[0] in ("add_constraint", "remove_constraint") + and isinstance(d[1], CheckConstraint) + ] + eq_(check_diffs, []) + + def test_no_reflection_graceful_skip_remove(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "t", + m1, + Column("x", Integer), + CheckConstraint("x > 0", name="ck_t_x_positive"), + ) + + Table( + "t", + m2, + Column("x", Integer), + ) + + diffs = self._fixture(m1, m2, opts=_ck_plugin_opts) + + check_diffs = [ + d + for d in diffs + if d[0] in ("add_constraint", "remove_constraint") + and isinstance(d[1], CheckConstraint) + ] + eq_(check_diffs, []) + + +class AutogenCheckConstraintRenderTest(TestBase): + + def setUp(self): + staging_env() + self.bind = config.db + + ctx_opts = { + "sqlalchemy_module_prefix": "sa.", + "alembic_module_prefix": "op.", + "target_metadata": MetaData(), + } + context = MigrationContext.configure( + dialect_name=self.bind.dialect.name, opts=ctx_opts + ) + self.autogen_context = api.AutogenContext(context) + + def tearDown(self): + clear_staging_env() + + def test_render_add_check_constraint(self): + m = MetaData() + t = Table("t", m, Column("x", Integer)) + ck = CheckConstraint(t.c.x > 0, name="ck_x_positive") + op_obj = ops.CreateCheckConstraintOp.from_constraint(ck) + + result = autogenerate.render_op_text(self.autogen_context, op_obj) + + assert "op.create_check_constraint(" in result + assert "'ck_x_positive'" in result + assert "'t'" in result + + def test_render_add_check_constraint_string_sqltext(self): + m = MetaData() + t = Table("t", m, Column("x", Integer)) + ck = CheckConstraint("x > 0", name="ck_x_positive") + t.append_constraint(ck) + op_obj = ops.CreateCheckConstraintOp.from_constraint(ck) + + result = autogenerate.render_op_text(self.autogen_context, op_obj) + + assert "op.create_check_constraint(" in result + assert "'ck_x_positive'" in result + + def test_render_drop_check_constraint(self): + m = MetaData() + t = Table("t", m, Column("x", Integer)) + ck = CheckConstraint(t.c.x > 0, name="ck_x_positive") + op_obj = ops.DropConstraintOp.from_constraint(ck) + + result = autogenerate.render_op_text(self.autogen_context, op_obj) + + assert "op.drop_constraint(" in result + assert "'ck_x_positive'" in result + + def test_render_add_check_constraint_with_schema(self): + m = MetaData() + t = Table("t", m, Column("x", Integer), schema="test_schema") + ck = CheckConstraint(t.c.x > 0, name="ck_x_positive") + op_obj = ops.CreateCheckConstraintOp.from_constraint(ck) + + result = autogenerate.render_op_text(self.autogen_context, op_obj) + + assert "op.create_check_constraint(" in result + assert "'ck_x_positive'" in result + assert "schema='test_schema'" in result + + +class AutogenCheckConstraintPluginOptInTest(AutogenFixtureTest, TestBase): + __backend__ = True + __requires__ = ("check_constraint_reflection",) + + def test_default_plugins_do_not_detect(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "t", + m1, + Column("x", Integer), + ) + + Table( + "t", + m2, + Column("x", Integer), + CheckConstraint("x > 0", name="ck_t_x_positive"), + ) + + diffs = self._fixture(m1, m2) + + check_diffs = [ + d + for d in diffs + if d[0] in ("add_constraint", "remove_constraint") + and isinstance(d[1], CheckConstraint) + ] + eq_(check_diffs, []) + + def test_opted_in_plugin_does_detect(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "t", + m1, + Column("x", Integer), + ) + + Table( + "t", + m2, + Column("x", Integer), + CheckConstraint("x > 0", name="ck_t_x_positive"), + ) + + diffs = self._fixture(m1, m2, opts=_ck_plugin_opts) + + eq_(len(diffs), 1) + eq_(diffs[0][0], "add_constraint") + eq_(diffs[0][1].name, "ck_t_x_positive") + + +class AutogenCheckConstraintNamingConvTest(AutogenFixtureTest, TestBase): + __backend__ = True + __requires__ = ("check_constraint_reflection",) + + def test_add_named_via_convention(self): + m1 = MetaData() + m2 = MetaData( + naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"} + ) + + Table("t", m1, Column("x", Integer)) + + Table( + "t", + m2, + Column("x", Integer), + CheckConstraint("x > 0", name="x_positive"), + ) + + diffs = self._fixture(m1, m2, opts=_ck_plugin_opts) + + eq_(len(diffs), 1) + eq_(diffs[0][0], "add_constraint") + eq_(diffs[0][1].name, "ck_t_x_positive") + + def test_remove_named_via_convention(self): + m1 = MetaData() + m2 = MetaData( + naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"} + ) + + Table( + "t", + m1, + Column("x", Integer), + CheckConstraint("x > 0", name="ck_t_x_positive"), + ) + + Table("t", m2, Column("x", Integer)) + + diffs = self._fixture(m1, m2, opts=_ck_plugin_opts) + + eq_(len(diffs), 1) + eq_(diffs[0][0], "remove_constraint") + eq_(diffs[0][1].name, "ck_t_x_positive") + + def test_no_change_named_via_convention(self): + m1 = MetaData() + m2 = MetaData( + naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"} + ) + + Table( + "t", + m1, + Column("x", Integer), + CheckConstraint("x > 0", name="ck_t_x_positive"), + ) + + Table( + "t", + m2, + Column("x", Integer), + CheckConstraint("x > 0", name="x_positive"), + ) + + diffs = self._fixture(m1, m2, opts=_ck_plugin_opts) + + eq_(diffs, []) From b0c36edd1ba11d31236e70b0a821e9a1c926275a Mon Sep 17 00:00:00 2001 From: bjorkbjork Date: Mon, 4 May 2026 13:47:38 +1000 Subject: [PATCH 2/2] fix pep484 CI: add type annotations for check constraint autogenerate --- alembic/autogenerate/compare/check_constraints.py | 15 ++++++++++++--- alembic/runtime/environment.py | 1 + 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/alembic/autogenerate/compare/check_constraints.py b/alembic/autogenerate/compare/check_constraints.py index b8a8e951..81133120 100644 --- a/alembic/autogenerate/compare/check_constraints.py +++ b/alembic/autogenerate/compare/check_constraints.py @@ -15,6 +15,7 @@ from ...util import sqla_compat if TYPE_CHECKING: + from sqlalchemy.engine.interfaces import ReflectedCheckConstraint from sqlalchemy.sql.elements import quoted_name from sqlalchemy.sql.schema import CheckConstraint from sqlalchemy.sql.schema import Table @@ -30,7 +31,7 @@ def _make_check_constraint( impl: DefaultImpl, - params: dict, + params: ReflectedCheckConstraint, conn_table: Table, ) -> CheckConstraint: const = sa_schema.CheckConstraint( @@ -96,8 +97,16 @@ def _compare_check_constraints( impl._create_reflected_constraint_sig(ck) for ck in conn_ck_objs } - metadata_ck_by_name = {c.name: c for c in metadata_ck_sig if c.name} - conn_ck_by_name = {c.name: c for c in conn_ck_sig if c.name} + metadata_ck_by_name = { + c.name: c + for c in metadata_ck_sig + if sqla_compat.constraint_name_string(c.name) + } + conn_ck_by_name = { + c.name: c + for c in conn_ck_sig + if sqla_compat.constraint_name_string(c.name) + } for removed_name in sorted( set(conn_ck_by_name).difference(metadata_ck_by_name) diff --git a/alembic/runtime/environment.py b/alembic/runtime/environment.py index 5817e2d9..49254dbb 100644 --- a/alembic/runtime/environment.py +++ b/alembic/runtime/environment.py @@ -58,6 +58,7 @@ "index", "unique_constraint", "foreign_key_constraint", + "check_constraint", ] NameFilterParentNames = MutableMapping[ Literal["schema_name", "table_name", "schema_qualified_table_name"],