From b174ff7835e83dd21b73593dc9e4078565e01caa Mon Sep 17 00:00:00 2001 From: Semyon Pupkov Date: Thu, 11 Dec 2025 12:11:01 +0500 Subject: [PATCH] Support skip_autocommit_rollback option for sqlalchemy --- .../phoenixdb/sqlalchemy_phoenix.py | 3 ++ .../phoenixdb/tests/test_sqlalchemy.py | 41 ++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/python-phoenixdb/phoenixdb/sqlalchemy_phoenix.py b/python-phoenixdb/phoenixdb/sqlalchemy_phoenix.py index 3df83ef..4b2e140 100644 --- a/python-phoenixdb/phoenixdb/sqlalchemy_phoenix.py +++ b/python-phoenixdb/phoenixdb/sqlalchemy_phoenix.py @@ -141,6 +141,9 @@ def create_connect_args(self, url): )) return [phoenix_url], connect_args + def detect_autocommit_setting(self, dbapi_conn): + return bool(dbapi_conn.autocommit) + def has_table(self, connection, table_name, schema=None, **kw): if schema is None: schema = '' diff --git a/python-phoenixdb/phoenixdb/tests/test_sqlalchemy.py b/python-phoenixdb/phoenixdb/tests/test_sqlalchemy.py index ee87c20..4649c58 100644 --- a/python-phoenixdb/phoenixdb/tests/test_sqlalchemy.py +++ b/python-phoenixdb/phoenixdb/tests/test_sqlalchemy.py @@ -15,6 +15,7 @@ import sys import unittest +from unittest import mock import sqlalchemy as db from sqlalchemy import text @@ -41,6 +42,40 @@ def test_connection(self): catalog = db.Table('CATALOG', metadata, schema='SYSTEM', autoload_with=engine) self.assertIn('TABLE_NAME', catalog.columns.keys()) + def test_set_autocommit(self): + engine = self._create_engine() + with engine.connect() as conn: + self.assertFalse(conn.connection.connection.autocommit) + + engine = self._create_engine(extra_connect_args={"autoCommit": True}) + with engine.connect() as conn: + self.assertTrue(conn.connection.connection.autocommit) + + @unittest.skipIf(db.__version__ < "2.0.43", "skip_autocommit_rollback added in 2.0.43") + def test_skip_autocommit_rollback_enabled(self): + engine = self._create_engine( + extra_connect_args={"autoCommit": True}, + skip_autocommit_rollback=True + ) + with engine.connect() as conn: + self.assertTrue(conn.connection.connection.autocommit) + + client = conn.connection.connection._client + with mock.patch.object(client, "rollback", wraps=client.rollback) as check_rollback: + conn.close() + self.assertEqual(len(check_rollback.mock_calls), 0) + + @unittest.skipIf(db.__version__ < "2.0.43", "skip_autocommit_rollback added in 2.0.43") + def test_skip_autocommit_rollback_disabled(self): + engine = self._create_engine(extra_connect_args={"autoCommit": True}) + with engine.connect() as conn: + self.assertTrue(conn.connection.connection.autocommit) + + client = conn.connection.connection._client + with mock.patch.object(client, "rollback", wraps=client.rollback) as check_rollback: + conn.close() + self.assertEqual(len(check_rollback.mock_calls), 1) + def test_textual(self): engine = self._create_engine() with engine.connect() as connection: @@ -154,7 +189,7 @@ def test_reflection(self): def test_orm(self): pass - def _create_engine(self): + def _create_engine(self, extra_connect_args=None, **kw): ''''Massage the properties that we use for the DBAPI tests so that they apply to SQLAlchemy''' @@ -173,5 +208,7 @@ def _create_engine(self): connect_args.update(avatica_password=TEST_DB_AVATICA_PASSWORD) if TEST_DB_TRUSTSTORE: connect_args.update(trustore=TEST_DB_TRUSTSTORE) + if extra_connect_args: + connect_args.update(**extra_connect_args) - return db.create_engine(urlunparse(url_parts), tls=tls, connect_args=connect_args) + return db.create_engine(urlunparse(url_parts), tls=tls, connect_args=connect_args, **kw)