From a38bd7bbcc21b766a7e0ff269e2b7ffcfd0c7dfd Mon Sep 17 00:00:00 2001 From: Fayaz Yusuf Khan Date: Sun, 22 Jun 2025 01:23:25 -0400 Subject: [PATCH 1/6] Remove warning suppression --- noxfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/noxfile.py b/noxfile.py index c9865be..ec7a560 100644 --- a/noxfile.py +++ b/noxfile.py @@ -110,7 +110,7 @@ def test(session, sqlalchemy): session.install(f"sqlalchemy~={sqlalchemy}.0") session.install("-e", ".") pytest_args = session.posargs or ["--pyargs", "sqlalchemy_mptt"] - session.run("pytest", *pytest_args, env={"SQLALCHEMY_SILENCE_UBER_WARNING": "1"}) + session.run("pytest", *pytest_args) @nox.session(default=False) From 31d6a838ebb102800e450a89611a6535f86d3c60 Mon Sep 17 00:00:00 2001 From: Fayaz Yusuf Khan Date: Sun, 22 Jun 2025 13:36:20 -0400 Subject: [PATCH 2/6] Show all SQLA 2.x warnings --- noxfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/noxfile.py b/noxfile.py index ec7a560..6a2c2dc 100644 --- a/noxfile.py +++ b/noxfile.py @@ -110,7 +110,7 @@ def test(session, sqlalchemy): session.install(f"sqlalchemy~={sqlalchemy}.0") session.install("-e", ".") pytest_args = session.posargs or ["--pyargs", "sqlalchemy_mptt"] - session.run("pytest", *pytest_args) + session.run("pytest", *pytest_args, env={"SQLALCHEMY_WARN_20": "1"}) @nox.session(default=False) From 48c8af837ee5484d18683b0dca1b386391de47cb Mon Sep 17 00:00:00 2001 From: Fayaz Yusuf Khan Date: Sun, 22 Jun 2025 13:39:15 -0400 Subject: [PATCH 3/6] Add compatibility layer for legacy SQLA versions --- sqlalchemy_mptt/sqlalchemy_compat.py | 40 ++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 sqlalchemy_mptt/sqlalchemy_compat.py diff --git a/sqlalchemy_mptt/sqlalchemy_compat.py b/sqlalchemy_mptt/sqlalchemy_compat.py new file mode 100644 index 0000000..a4083b2 --- /dev/null +++ b/sqlalchemy_mptt/sqlalchemy_compat.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# vim:fenc=utf-8 +# +# Copyright (c) 2025 Fayaz Yusuf Khan +# Distributed under terms of the MIT license. +"""Compatibility layer for SQLAlchemy versions.""" +import sqlalchemy as sa + + +if sa.__version__ < '1.4': + from sqlalchemy.ext.declarative import declarative_base +else: + from sqlalchemy.orm import declarative_base + + +def select(*args, **kwargs): + """Compatibility function for select.""" + if sa.__version__ < '1.4': + return sa.select(args, **kwargs) + else: + return sa.select(*args, **kwargs) + + +def case(*args, **kwargs): + """Compatibility function for case.""" + if sa.__version__ < '1.4': + return sa.case(args, **kwargs) + else: + return sa.case(*args, **kwargs) + + +def get(session, model, id): + """Compatibility function for getting an object by ID.""" + if sa.__version__ < '1.4': + return session.query(model).get(id) + else: + return session.get(model, id) + + +__all__ = ["case", "declarative_base", "select"] From e4535ec3880861c71ad1849dc5132e6a21b0903e Mon Sep 17 00:00:00 2001 From: Fayaz Yusuf Khan Date: Sun, 22 Jun 2025 13:48:12 -0400 Subject: [PATCH 4/6] Address warnings --- sqlalchemy_mptt/events.py | 185 ++++++++-------------- sqlalchemy_mptt/tests/test_events.py | 7 +- sqlalchemy_mptt/tests/test_inheritance.py | 21 +-- sqlalchemy_mptt/tests/test_mixins.py | 4 +- sqlalchemy_mptt/tests/test_stateful.py | 7 +- 5 files changed, 89 insertions(+), 135 deletions(-) diff --git a/sqlalchemy_mptt/events.py b/sqlalchemy_mptt/events.py index cdd1d86..a825a21 100644 --- a/sqlalchemy_mptt/events.py +++ b/sqlalchemy_mptt/events.py @@ -13,11 +13,13 @@ import weakref # SQLAlchemy -from sqlalchemy import and_, case, event, select, inspection +from sqlalchemy import and_, event, inspection from sqlalchemy.orm import object_session from sqlalchemy.sql import func from sqlalchemy.orm.base import NO_VALUE +from sqlalchemy_mptt.sqlalchemy_compat import case, select + def _insert_subtree( table, @@ -41,9 +43,9 @@ def _insert_subtree( delta_rgt = delta_lft + node_size - 1 connection.execute( - table.update( - table_pk.in_(subtree) - ).values( + table.update() + .where(table_pk.in_(subtree)) + .values( lft=table.c.lft - node_pos_left + delta_lft, rgt=table.c.rgt - node_pos_right + delta_rgt, level=table.c.level - node_level + parent_level + 1, @@ -53,21 +55,14 @@ def _insert_subtree( # step 2: update key of right side connection.execute( - table.update( - and_( - table.c.rgt > delta_lft - 1, - table_pk.notin_(subtree), - table.c.tree_id == parent_tree_id - ) - ).values( + table.update() + .where(table.c.rgt > delta_lft - 1) + .where(table_pk.notin_(subtree)) + .where(table.c.tree_id == parent_tree_id) + .values( rgt=table.c.rgt + node_size, lft=case( - [ - ( - table.c.lft > left_sibling['lft'], - table.c.lft + node_size - ) - ], + (table.c.lft > left_sibling['lft'], table.c.lft + node_size), else_=table.c.lft ) ) @@ -94,9 +89,7 @@ def mptt_before_insert(mapper, connection, instance): instance.level = instance.get_default_level() tree_id = connection.scalar( select( - [ - func.max(table.c.tree_id) + 1 - ] + func.max(table.c.tree_id) + 1 ) ) or 1 instance.tree_id = tree_id @@ -106,12 +99,10 @@ def mptt_before_insert(mapper, connection, instance): parent_tree_id, parent_level) = connection.execute( select( - [ - table.c.lft, - table.c.rgt, - table.c.tree_id, - table.c.level - ] + table.c.lft, + table.c.rgt, + table.c.tree_id, + table.c.level ).where( table_pk == instance.parent_id ) @@ -119,26 +110,16 @@ def mptt_before_insert(mapper, connection, instance): # Update key of right side connection.execute( - table.update( - and_(table.c.rgt >= parent_pos_right, - table.c.tree_id == parent_tree_id) - ).values( + table.update() + .where(table.c.rgt >= parent_pos_right) + .where(table.c.tree_id == parent_tree_id) + .values( lft=case( - [ - ( - table.c.lft > parent_pos_right, - table.c.lft + 2 - ) - ], + (table.c.lft > parent_pos_right, table.c.lft + 2), else_=table.c.lft ), rgt=case( - [ - ( - table.c.rgt >= parent_pos_right, - table.c.rgt + 2 - ) - ], + (table.c.rgt >= parent_pos_right, table.c.rgt + 2), else_=table.c.rgt ) ) @@ -158,10 +139,8 @@ def mptt_before_delete(mapper, connection, instance, delete=True): table_pk = getattr(table.c, db_pk.name) lft, rgt = connection.execute( select( - [ - table.c.lft, - table.c.rgt - ] + table.c.lft, + table.c.rgt ).where( table_pk == pk ) @@ -171,7 +150,7 @@ def mptt_before_delete(mapper, connection, instance, delete=True): if delete: mapper.base_mapper.confirm_deleted_rows = False connection.execute( - table.delete( + table.delete().where( table_pk == pk ) ) @@ -190,28 +169,16 @@ def mptt_before_delete(mapper, connection, instance, delete=True): END """ connection.execute( - table.update( - and_( - table.c.rgt > rgt, - table.c.tree_id == tree_id - ) - ).values( + table.update() + .where(table.c.rgt > rgt) + .where(table.c.tree_id == tree_id) + .values( lft=case( - [ - ( - table.c.lft > lft, - table.c.lft - delta - ) - ], + (table.c.lft > lft, table.c.lft - delta), else_=table.c.lft ), rgt=case( - [ - ( - table.c.rgt >= rgt, - table.c.rgt - delta - ) - ], + (table.c.rgt >= rgt, table.c.rgt - delta), else_=table.c.rgt ) ) @@ -243,25 +210,21 @@ def mptt_before_update(mapper, connection, instance): right_sibling_tree_id ) = connection.execute( select( - [ - table.c.lft, - table.c.rgt, - table.c.parent_id, - table.c.level, - table.c.tree_id - ] + table.c.lft, + table.c.rgt, + table.c.parent_id, + table.c.level, + table.c.tree_id ).where( table_pk == instance.mptt_move_before ) ).fetchone() current_lvl_nodes = connection.execute( select( - [ - table.c.lft, - table.c.rgt, - table.c.parent_id, - table.c.tree_id - ] + table.c.lft, + table.c.rgt, + table.c.parent_id, + table.c.tree_id ).where( and_( table.c.level == right_sibling_level, @@ -296,12 +259,10 @@ def mptt_before_update(mapper, connection, instance): left_sibling_tree_id ) = connection.execute( select( - [ - table.c.lft, - table.c.rgt, - table.c.parent_id, - table.c.tree_id - ] + table.c.lft, + table.c.rgt, + table.c.parent_id, + table.c.tree_id ).where( table_pk == instance.mptt_move_after ) @@ -320,7 +281,7 @@ def mptt_before_update(mapper, connection, instance): ORDER BY left_key """ subtree = connection.execute( - select([table_pk]) + select(table_pk) .where( and_( table.c.lft >= instance.left, @@ -345,13 +306,11 @@ def mptt_before_update(mapper, connection, instance): node_level ) = connection.execute( select( - [ - table.c.lft, - table.c.rgt, - table.c.tree_id, - table.c.parent_id, - table.c.level - ] + table.c.lft, + table.c.rgt, + table.c.tree_id, + table.c.parent_id, + table.c.level ).where( table_pk == node_id ) @@ -375,13 +334,11 @@ def mptt_before_update(mapper, connection, instance): parent_level ) = connection.execute( select( - [ - table_pk, - table.c.rgt, - table.c.lft, - table.c.tree_id, - table.c.level - ] + table_pk, + table.c.rgt, + table.c.lft, + table.c.tree_id, + table.c.level ).where( table_pk == instance.parent_id ) @@ -405,13 +362,11 @@ def mptt_before_update(mapper, connection, instance): parent_level ) = connection.execute( select( - [ - table_pk, - table.c.rgt, - table.c.lft, - table.c.tree_id, - table.c.level - ] + table_pk, + table.c.rgt, + table.c.lft, + table.c.tree_id, + table.c.level ).where( table_pk == instance.parent_id ) @@ -449,9 +404,9 @@ def mptt_before_update(mapper, connection, instance): if left_sibling_tree_id or left_sibling_tree_id == 0: tree_id = left_sibling_tree_id + 1 connection.execute( - table.update( - table.c.tree_id > left_sibling_tree_id - ).values( + table.update() + .where(table.c.tree_id > left_sibling_tree_id) + .values( tree_id=table.c.tree_id + 1 ) ) @@ -459,18 +414,14 @@ def mptt_before_update(mapper, connection, instance): else: tree_id = connection.scalar( select( - [ - func.max(table.c.tree_id) + 1 - ] + func.max(table.c.tree_id) + 1 ) ) connection.execute( - table.update( - table_pk.in_( - subtree - ) - ).values( + table.update() + .where(table_pk.in_(subtree)) + .values( lft=table.c.lft - node_pos_left + 1, rgt=table.c.rgt - node_pos_left + 1, level=table.c.level - node_level + default_level, diff --git a/sqlalchemy_mptt/tests/test_events.py b/sqlalchemy_mptt/tests/test_events.py index e4bb734..a41aa1c 100644 --- a/sqlalchemy_mptt/tests/test_events.py +++ b/sqlalchemy_mptt/tests/test_events.py @@ -14,13 +14,14 @@ from sqlalchemy import Column, Boolean, Integer, create_engine from sqlalchemy.event import contains -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from sqlalchemy_mptt import mptt_sessionmaker -from . import TreeTestingMixin -from ..mixins import BaseNestedSets +from sqlalchemy_mptt.mixins import BaseNestedSets +from sqlalchemy_mptt.sqlalchemy_compat import declarative_base +from sqlalchemy_mptt.tests import TreeTestingMixin + Base = declarative_base() diff --git a/sqlalchemy_mptt/tests/test_inheritance.py b/sqlalchemy_mptt/tests/test_inheritance.py index 3434e93..3364a38 100644 --- a/sqlalchemy_mptt/tests/test_inheritance.py +++ b/sqlalchemy_mptt/tests/test_inheritance.py @@ -1,11 +1,12 @@ import unittest import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker -from . import TreeTestingMixin, failures_expected_on -from ..mixins import BaseNestedSets +from sqlalchemy_mptt.mixins import BaseNestedSets +from sqlalchemy_mptt.sqlalchemy_compat import declarative_base, get +from sqlalchemy_mptt.tests import TreeTestingMixin, failures_expected_on + Base = declarative_base() @@ -59,7 +60,7 @@ def test_create_generic(self): self.session.add(GenericTree(ppk=1)) self.session.commit() - tree = self.session.query(GenericTree).get(1) + tree = get(self.session, GenericTree, 1) self.assertEqual(tree.ppk, 1) self.assertEqual(tree.tree_id, 1) @@ -67,7 +68,7 @@ def test_create_spec(self): self.session.add(SpecializedTree(ppk=1)) self.session.commit() - tree = self.session.query(SpecializedTree).get(1) + tree = get(self.session, SpecializedTree, 1) self.assertEqual(tree.ppk, 1) self.assertEqual(tree.tree_id, 1) @@ -83,21 +84,21 @@ def test_create_delete(self): self.session.add(parent) self.session.commit() - tree = self.session.query(SpecializedTree).get(1) + tree = get(self.session, SpecializedTree, 1) self.assertEqual(tree.ppk, 1) self.assertEqual(tree.tree_id, 1) self.session.delete(child1) self.session.commit() - self.assertEqual(None, self.session.query(SpecializedTree).get(2)) + self.assertEqual(None, get(self.session, SpecializedTree, 2)) self.session.delete(child2) self.session.commit() - self.assertEqual(None, self.session.query(SpecializedTree).get(3)) - self.assertEqual(None, self.session.query(SpecializedTree).get(4)) - self.assertEqual(None, self.session.query(SpecializedTree).get(5)) + self.assertEqual(None, get(self.session, SpecializedTree, 3)) + self.assertEqual(None, get(self.session, SpecializedTree, 4)) + self.assertEqual(None, get(self.session, SpecializedTree, 5)) class TestGenericTree(TreeTestingMixin, unittest.TestCase): diff --git a/sqlalchemy_mptt/tests/test_mixins.py b/sqlalchemy_mptt/tests/test_mixins.py index 0bc1bf9..5f039d3 100644 --- a/sqlalchemy_mptt/tests/test_mixins.py +++ b/sqlalchemy_mptt/tests/test_mixins.py @@ -12,9 +12,9 @@ import unittest from sqlalchemy import Column, Integer -from sqlalchemy.ext.declarative import declarative_base -from ..mixins import BaseNestedSets +from sqlalchemy_mptt.mixins import BaseNestedSets +from sqlalchemy_mptt.sqlalchemy_compat import declarative_base Base = declarative_base() diff --git a/sqlalchemy_mptt/tests/test_stateful.py b/sqlalchemy_mptt/tests/test_stateful.py index 4003e74..b2bc286 100644 --- a/sqlalchemy_mptt/tests/test_stateful.py +++ b/sqlalchemy_mptt/tests/test_stateful.py @@ -8,10 +8,10 @@ from hypothesis import HealthCheck, settings, strategies as st from hypothesis.stateful import Bundle, RuleBasedStateMachine, consumes, invariant, rule from sqlalchemy import Column, Integer, Boolean, create_engine -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import joinedload, sessionmaker from sqlalchemy_mptt import BaseNestedSets, mptt_sessionmaker +from sqlalchemy_mptt.sqlalchemy_compat import declarative_base Base = declarative_base() @@ -62,8 +62,9 @@ def delete_node(self, node): @rule(target=node, node=node, visible=st.none() | st.booleans()) def add_child(self, node, visible): - child = Tree(parent=node, visible=visible) - self.session.add(child) + # Avoid cascade_backrefs here since it is deprecated. + child = Tree(visible=visible) + node.children.append(child) self.session.commit() assert node.left < child.left < child.right < node.right return child From 8e55e1387487654da12a9016ff213233fa038350 Mon Sep 17 00:00:00 2001 From: Fayaz Yusuf Khan Date: Sun, 22 Jun 2025 16:02:06 -0400 Subject: [PATCH 5/6] Add copyright notice --- sqlalchemy_mptt/events.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sqlalchemy_mptt/events.py b/sqlalchemy_mptt/events.py index a825a21..69dbe4c 100644 --- a/sqlalchemy_mptt/events.py +++ b/sqlalchemy_mptt/events.py @@ -3,6 +3,7 @@ # vim:fenc=utf-8 # # Copyright © 2014 uralbash +# Copyright (c) 2025 Fayaz Yusuf Khan # # Distributed under terms of the MIT license. From 710656fd00511ca7f213705d681cff05f14fed09 Mon Sep 17 00:00:00 2001 From: Fayaz Yusuf Khan Date: Sun, 22 Jun 2025 16:57:09 -0400 Subject: [PATCH 6/6] Refactor version detection logic --- sqlalchemy_mptt/events.py | 34 +++++++------- sqlalchemy_mptt/sqlalchemy_compat.py | 57 ++++++++++++++--------- sqlalchemy_mptt/tests/test_events.py | 4 +- sqlalchemy_mptt/tests/test_inheritance.py | 20 ++++---- sqlalchemy_mptt/tests/test_mixins.py | 4 +- sqlalchemy_mptt/tests/test_stateful.py | 4 +- 6 files changed, 69 insertions(+), 54 deletions(-) diff --git a/sqlalchemy_mptt/events.py b/sqlalchemy_mptt/events.py index 69dbe4c..718c315 100644 --- a/sqlalchemy_mptt/events.py +++ b/sqlalchemy_mptt/events.py @@ -19,7 +19,7 @@ from sqlalchemy.sql import func from sqlalchemy.orm.base import NO_VALUE -from sqlalchemy_mptt.sqlalchemy_compat import case, select +from sqlalchemy_mptt.sqlalchemy_compat import compat_layer def _insert_subtree( @@ -62,7 +62,7 @@ def _insert_subtree( .where(table.c.tree_id == parent_tree_id) .values( rgt=table.c.rgt + node_size, - lft=case( + lft=compat_layer.case( (table.c.lft > left_sibling['lft'], table.c.lft + node_size), else_=table.c.lft ) @@ -89,7 +89,7 @@ def mptt_before_insert(mapper, connection, instance): instance.right = 2 instance.level = instance.get_default_level() tree_id = connection.scalar( - select( + compat_layer.select( func.max(table.c.tree_id) + 1 ) ) or 1 @@ -99,7 +99,7 @@ def mptt_before_insert(mapper, connection, instance): parent_pos_right, parent_tree_id, parent_level) = connection.execute( - select( + compat_layer.select( table.c.lft, table.c.rgt, table.c.tree_id, @@ -115,11 +115,11 @@ def mptt_before_insert(mapper, connection, instance): .where(table.c.rgt >= parent_pos_right) .where(table.c.tree_id == parent_tree_id) .values( - lft=case( + lft=compat_layer.case( (table.c.lft > parent_pos_right, table.c.lft + 2), else_=table.c.lft ), - rgt=case( + rgt=compat_layer.case( (table.c.rgt >= parent_pos_right, table.c.rgt + 2), else_=table.c.rgt ) @@ -139,7 +139,7 @@ def mptt_before_delete(mapper, connection, instance, delete=True): db_pk = instance.get_pk_column() table_pk = getattr(table.c, db_pk.name) lft, rgt = connection.execute( - select( + compat_layer.select( table.c.lft, table.c.rgt ).where( @@ -174,11 +174,11 @@ def mptt_before_delete(mapper, connection, instance, delete=True): .where(table.c.rgt > rgt) .where(table.c.tree_id == tree_id) .values( - lft=case( + lft=compat_layer.case( (table.c.lft > lft, table.c.lft - delta), else_=table.c.lft ), - rgt=case( + rgt=compat_layer.case( (table.c.rgt >= rgt, table.c.rgt - delta), else_=table.c.rgt ) @@ -210,7 +210,7 @@ def mptt_before_update(mapper, connection, instance): right_sibling_level, right_sibling_tree_id ) = connection.execute( - select( + compat_layer.select( table.c.lft, table.c.rgt, table.c.parent_id, @@ -221,7 +221,7 @@ def mptt_before_update(mapper, connection, instance): ) ).fetchone() current_lvl_nodes = connection.execute( - select( + compat_layer.select( table.c.lft, table.c.rgt, table.c.parent_id, @@ -259,7 +259,7 @@ def mptt_before_update(mapper, connection, instance): left_sibling_parent, left_sibling_tree_id ) = connection.execute( - select( + compat_layer.select( table.c.lft, table.c.rgt, table.c.parent_id, @@ -282,7 +282,7 @@ def mptt_before_update(mapper, connection, instance): ORDER BY left_key """ subtree = connection.execute( - select(table_pk) + compat_layer.select(table_pk) .where( and_( table.c.lft >= instance.left, @@ -306,7 +306,7 @@ def mptt_before_update(mapper, connection, instance): node_parent_id, node_level ) = connection.execute( - select( + compat_layer.select( table.c.lft, table.c.rgt, table.c.tree_id, @@ -334,7 +334,7 @@ def mptt_before_update(mapper, connection, instance): parent_tree_id, parent_level ) = connection.execute( - select( + compat_layer.select( table_pk, table.c.rgt, table.c.lft, @@ -362,7 +362,7 @@ def mptt_before_update(mapper, connection, instance): parent_tree_id, parent_level ) = connection.execute( - select( + compat_layer.select( table_pk, table.c.rgt, table.c.lft, @@ -414,7 +414,7 @@ def mptt_before_update(mapper, connection, instance): # if just insert else: tree_id = connection.scalar( - select( + compat_layer.select( func.max(table.c.tree_id) + 1 ) ) diff --git a/sqlalchemy_mptt/sqlalchemy_compat.py b/sqlalchemy_mptt/sqlalchemy_compat.py index a4083b2..24cd090 100644 --- a/sqlalchemy_mptt/sqlalchemy_compat.py +++ b/sqlalchemy_mptt/sqlalchemy_compat.py @@ -7,34 +7,49 @@ import sqlalchemy as sa -if sa.__version__ < '1.4': - from sqlalchemy.ext.declarative import declarative_base -else: - from sqlalchemy.orm import declarative_base +class LegacySQLAlchemyAPI: + """A class to provide compatibility for legacy SQLAlchemy versions (1.0 - 1.3).""" + @staticmethod + def declarative_base(*args, **kwargs): + from sqlalchemy.ext.declarative import declarative_base + return declarative_base(*args, **kwargs) -def select(*args, **kwargs): - """Compatibility function for select.""" - if sa.__version__ < '1.4': + @staticmethod + def select(*args, **kwargs): return sa.select(args, **kwargs) - else: - return sa.select(*args, **kwargs) - -def case(*args, **kwargs): - """Compatibility function for case.""" - if sa.__version__ < '1.4': + @staticmethod + def case(*args, **kwargs): return sa.case(args, **kwargs) - else: - return sa.case(*args, **kwargs) - -def get(session, model, id): - """Compatibility function for getting an object by ID.""" - if sa.__version__ < '1.4': + @staticmethod + def get(session, model, id): return session.query(model).get(id) - else: + + +class ModernSQLAlchemyAPI: + """A class to provide compatibility for modern SQLAlchemy versions (1.4+).""" + + @staticmethod + def declarative_base(*args, **kwargs): + from sqlalchemy.orm import declarative_base + return declarative_base(*args, **kwargs) + + @staticmethod + def select(*args, **kwargs): + return sa.select(*args, **kwargs) + + @staticmethod + def case(*args, **kwargs): + return sa.case(*args, **kwargs) + + @staticmethod + def get(session, model, id): return session.get(model, id) -__all__ = ["case", "declarative_base", "select"] +if sa.__version__ < '1.4': + compat_layer = LegacySQLAlchemyAPI() +else: + compat_layer = ModernSQLAlchemyAPI() diff --git a/sqlalchemy_mptt/tests/test_events.py b/sqlalchemy_mptt/tests/test_events.py index a41aa1c..f1c80e3 100644 --- a/sqlalchemy_mptt/tests/test_events.py +++ b/sqlalchemy_mptt/tests/test_events.py @@ -19,11 +19,11 @@ from sqlalchemy_mptt import mptt_sessionmaker from sqlalchemy_mptt.mixins import BaseNestedSets -from sqlalchemy_mptt.sqlalchemy_compat import declarative_base +from sqlalchemy_mptt.sqlalchemy_compat import compat_layer from sqlalchemy_mptt.tests import TreeTestingMixin -Base = declarative_base() +Base = compat_layer.declarative_base() class Tree(Base, BaseNestedSets): diff --git a/sqlalchemy_mptt/tests/test_inheritance.py b/sqlalchemy_mptt/tests/test_inheritance.py index 3364a38..1ac271d 100644 --- a/sqlalchemy_mptt/tests/test_inheritance.py +++ b/sqlalchemy_mptt/tests/test_inheritance.py @@ -4,11 +4,11 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy_mptt.mixins import BaseNestedSets -from sqlalchemy_mptt.sqlalchemy_compat import declarative_base, get +from sqlalchemy_mptt.sqlalchemy_compat import compat_layer from sqlalchemy_mptt.tests import TreeTestingMixin, failures_expected_on -Base = declarative_base() +Base = compat_layer.declarative_base() class GenericTree(Base, BaseNestedSets): @@ -60,7 +60,7 @@ def test_create_generic(self): self.session.add(GenericTree(ppk=1)) self.session.commit() - tree = get(self.session, GenericTree, 1) + tree = compat_layer.get(self.session, GenericTree, 1) self.assertEqual(tree.ppk, 1) self.assertEqual(tree.tree_id, 1) @@ -68,7 +68,7 @@ def test_create_spec(self): self.session.add(SpecializedTree(ppk=1)) self.session.commit() - tree = get(self.session, SpecializedTree, 1) + tree = compat_layer.get(self.session, SpecializedTree, 1) self.assertEqual(tree.ppk, 1) self.assertEqual(tree.tree_id, 1) @@ -84,21 +84,21 @@ def test_create_delete(self): self.session.add(parent) self.session.commit() - tree = get(self.session, SpecializedTree, 1) + tree = compat_layer.get(self.session, SpecializedTree, 1) self.assertEqual(tree.ppk, 1) self.assertEqual(tree.tree_id, 1) self.session.delete(child1) self.session.commit() - self.assertEqual(None, get(self.session, SpecializedTree, 2)) + self.assertEqual(None, compat_layer.get(self.session, SpecializedTree, 2)) self.session.delete(child2) self.session.commit() - self.assertEqual(None, get(self.session, SpecializedTree, 3)) - self.assertEqual(None, get(self.session, SpecializedTree, 4)) - self.assertEqual(None, get(self.session, SpecializedTree, 5)) + self.assertEqual(None, compat_layer.get(self.session, SpecializedTree, 3)) + self.assertEqual(None, compat_layer.get(self.session, SpecializedTree, 4)) + self.assertEqual(None, compat_layer.get(self.session, SpecializedTree, 5)) class TestGenericTree(TreeTestingMixin, unittest.TestCase): @@ -116,7 +116,7 @@ def test_rebuild(self): super().test_rebuild() -Base2 = declarative_base() +Base2 = compat_layer.declarative_base() class BaseInheritance(Base2): diff --git a/sqlalchemy_mptt/tests/test_mixins.py b/sqlalchemy_mptt/tests/test_mixins.py index 5f039d3..cd523df 100644 --- a/sqlalchemy_mptt/tests/test_mixins.py +++ b/sqlalchemy_mptt/tests/test_mixins.py @@ -14,10 +14,10 @@ from sqlalchemy import Column, Integer from sqlalchemy_mptt.mixins import BaseNestedSets -from sqlalchemy_mptt.sqlalchemy_compat import declarative_base +from sqlalchemy_mptt.sqlalchemy_compat import compat_layer -Base = declarative_base() +Base = compat_layer.declarative_base() class Tree2(Base, BaseNestedSets): diff --git a/sqlalchemy_mptt/tests/test_stateful.py b/sqlalchemy_mptt/tests/test_stateful.py index b2bc286..be93705 100644 --- a/sqlalchemy_mptt/tests/test_stateful.py +++ b/sqlalchemy_mptt/tests/test_stateful.py @@ -11,10 +11,10 @@ from sqlalchemy.orm import joinedload, sessionmaker from sqlalchemy_mptt import BaseNestedSets, mptt_sessionmaker -from sqlalchemy_mptt.sqlalchemy_compat import declarative_base +from sqlalchemy_mptt.sqlalchemy_compat import compat_layer -Base = declarative_base() +Base = compat_layer.declarative_base() class Tree(Base, BaseNestedSets):