diff --git a/djongo/base.py b/djongo/base.py index ba6766d7..e10b847f 100644 --- a/djongo/base.py +++ b/djongo/base.py @@ -1,6 +1,10 @@ """ MongoDB database backend for Django """ + +# THIS FILE WAS CHANGED ON - 05 Sep 2022 + +import traceback from collections import OrderedDict from logging import getLogger from django.db.backends.base.base import BaseDatabaseWrapper @@ -10,10 +14,12 @@ from .creation import DatabaseCreation from . import database as Database from .cursor import Cursor +from .database import DatabaseError from .features import DatabaseFeatures from .introspection import DatabaseIntrospection from .operations import DatabaseOperations from .schema import DatabaseSchemaEditor +from .transaction import Transaction logger = getLogger(__name__) @@ -34,9 +40,10 @@ def __contains__(self, item): class DjongoClient: - def __init__(self, database, enforce_schema=True): + def __init__(self, database, enforce_schema=True, session=None): self.enforce_schema = enforce_schema self.cached_collections = CachedCollections(database) + self.session = session class DatabaseWrapper(BaseDatabaseWrapper): @@ -115,6 +122,8 @@ class DatabaseWrapper(BaseDatabaseWrapper): def __init__(self, *args, **kwargs): self.client_connection = None self.djongo_connection = None + self.transaction = None + self.rollbacked = False super().__init__(*args, **kwargs) def is_usable(self): @@ -187,7 +196,14 @@ def _set_autocommit(self, autocommit): TODO: For future reference, setting two phase commits and rollbacks might require populating this method. """ - pass + self.autocommit = False + + def set_autocommit( + self, autocommit, force_begin_transaction_with_broken_autocommit=False + ): + result = super().set_autocommit(autocommit, force_begin_transaction_with_broken_autocommit=False) + self.autocommit = False + return result def init_connection_state(self): try: @@ -220,3 +236,34 @@ def _commit(self): TODO: two phase commits are not supported yet. """ pass + + def _savepoint(self, sid): + # add _savepoint method to work with Django's transactions + self.in_atomic_block = True + self.transaction = Transaction(self.client_connection) + connection_params = self.get_connection_params() + + name = connection_params.pop('name') + + # this will be used in sql2mongo/query.py as session parameter when using pymongo CRUD operations + self.djongo_connection.session = self.transaction.session + # this will be used in models/fields.py as session parameter when using pymongo CRUD operations + # in that file pymongo functions are prefixed with 'mongo_' + self.client_connection[name].__setattr__('session', self.transaction.session) + + + def _savepoint_commit(self, sid): + # add _savepoint_commit method to work with Django's transactions + # this method is executed even if rollback is executed after it + connection_params = self.get_connection_params() + name = connection_params.pop('name') + if not self.rollbacked: + self.transaction.__exit__(None, None, traceback) + self.djongo_connection = DjongoClient(self.connection, self.get_connection_params().get('enforce_schema')) + self.client_connection[name].__setattr__('session', None) + + def _savepoint_rollback(self, sid): + self.rollbacked = True + # We have to pass in some error, but it is not used anywhere as far as known + self.transaction.__exit__('DatabaseError', DatabaseError('Error in transaction; rollbacked'), traceback) + diff --git a/djongo/cursor.py b/djongo/cursor.py index f5a4f4b8..c51176ee 100644 --- a/djongo/cursor.py +++ b/djongo/cursor.py @@ -1,3 +1,5 @@ +# THIS FILE WAS CHANGED ON - 05 Sep 2022 + from logging import getLogger from .database import DatabaseError @@ -55,7 +57,7 @@ def execute(self, sql, params=None): sql, params) except Exception as e: - db_exe = DatabaseError() + db_exe = DatabaseError(str(e)) raise db_exe from e def fetchmany(self, size=1): diff --git a/djongo/features.py b/djongo/features.py index acfee1e3..e63d447a 100644 --- a/djongo/features.py +++ b/djongo/features.py @@ -1,3 +1,5 @@ +# THIS FILE WAS CHANGED ON - 05 Sep 2022 + from django.db.backends.base.features import BaseDatabaseFeatures @@ -7,7 +9,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): has_bulk_insert = True has_native_uuid_field = True supports_timezones = False - uses_savepoints = False + uses_savepoints = True can_clone_databases = True test_db_allows_multiple_connections = False supports_unspecified_pk = True diff --git a/djongo/models/fields.py b/djongo/models/fields.py index d34d8aa2..b21b584b 100644 --- a/djongo/models/fields.py +++ b/djongo/models/fields.py @@ -13,7 +13,7 @@ These are the main fields for working with MongoDB. """ -# THIS FILE WAS CHANGED ON - 28 Mar 2022 +# THIS FILE WAS CHANGED ON - 05 Sep 2022 import functools import json @@ -781,7 +781,8 @@ def add(self, *objs): lh_field.get_attname(): getattr(self.instance, rh_field.get_attname()) } - } + }, + session=pymongo_connections[self.db].djongo_connection.session ) for obj in objs: fk_field = getattr(obj, lh_field.get_attname()) @@ -877,7 +878,8 @@ def add(self, *objs): '$each': list(new_fks) } } - } + }, + session=pymongo_connections[self.db].djongo_connection.session ) add.alters_data = True @@ -903,7 +905,8 @@ def _remove(self, to_del): '$in': list(to_del) } } - } + }, + session=pymongo_connections[self.db].djongo_connection.session ) def clear(self): @@ -914,7 +917,8 @@ def clear(self): '$set': { self.field.attname: [] } - } + }, + session=pymongo_connections[self.db].djongo_connection.session ) setattr(self.instance, self.field.attname, set()) diff --git a/djongo/sql2mongo/operators.py b/djongo/sql2mongo/operators.py index bb3bc422..93c0bdfc 100644 --- a/djongo/sql2mongo/operators.py +++ b/djongo/sql2mongo/operators.py @@ -52,6 +52,8 @@ def __init__( self.is_negated = False self._name = name self.precedence = OPERATOR_PRECEDENCE[name] + + self._options = {} def negate(self): raise NotImplementedError @@ -508,6 +510,22 @@ def to_mongo(self): return self._op.to_mongo() +class CaseOp(_Op, _StatementParser): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.statement.skip(2) + self._statement2ops() + self.evaluate() + + def negate(self): + raise NotImplementedError + + def to_mongo(self): + return self._op.to_mongo() + + class ParenthesisOp(_Op, _StatementParser): def to_mongo(self): @@ -525,6 +543,7 @@ def negate(self): class CmpOp(_Op): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._identifier = SQLToken.token2sql(self.statement.left, self.query) @@ -532,7 +551,8 @@ def __init__(self, *args, **kwargs): if isinstance(self.statement.right, Identifier): raise SQLDecodeError('Join using WHERE not supported') - self._operator = OPERATOR_MAP[self.statement.token_next(0)[1].value] + sql_operation = self.statement.token_next(0)[1].value + self._operator = OPERATOR_MAP[sql_operation] index = re_index(self.statement.right.value) if self._operator in NEW_OPERATORS: @@ -541,6 +561,12 @@ def __init__(self, *args, **kwargs): else: self._constant = self.params[index] if index is not None else MAP_INDEX_NONE[self.statement.right.value] + if sql_operation == 'iLIKE': + self._options = {'$options': 'im'} + self._constant = self._make_regex(self._constant) + if sql_operation == 'LIKE': + self._constant = self._make_regex(self._constant) + if isinstance(self._constant, dict): self._field_ext, self._constant = next(iter(self._constant.items())) else: @@ -558,9 +584,40 @@ def to_mongo(self): field += '.' + self._field_ext if not self.is_negated: - return {field: {self._operator: self._constant}} + return {field: {self._operator: self._constant, **self._options}} + else: + return {field: {'$not': {self._operator: self._constant, **self._options}}} + + @staticmethod + def _check_embedded(to_match): + try: + check_dict = to_match + replace_chars = "\\%'" + for c in replace_chars: + if c == "'": + check_dict = check_dict.replace("'", '"') + else: + check_dict = check_dict.replace(c, "") + check_dict = json.loads(check_dict) + if isinstance(check_dict, dict): + return check_dict + else: + return to_match + except Exception as e: + return to_match + + @staticmethod + def _make_regex(to_match): + to_match = CmpOp._check_embedded(to_match) + if isinstance(to_match, str): + to_match = to_match.replace('%', '.*') + regex = '^' + to_match + '$' + elif isinstance(to_match, dict): + field_ext, to_match = next(iter(to_match.items())) + regex = to_match else: - return {field: {'$not': {self._operator: self._constant}}} + raise SQLDecodeError + return regex OPERATOR_MAP = { @@ -570,7 +627,9 @@ def to_mongo(self): '>=': '$gte', '<=': '$lte', 'IN': '$in', - 'NOT IN': '$nin' + 'NOT IN': '$nin', + 'LIKE': '$regex', + 'iLIKE': '$regex', } OPERATOR_PRECEDENCE = { 'IS': 8, diff --git a/djongo/sql2mongo/query.py b/djongo/sql2mongo/query.py index 465c6879..4a634bd6 100644 --- a/djongo/sql2mongo/query.py +++ b/djongo/sql2mongo/query.py @@ -3,7 +3,7 @@ SQL constructors. """ -# THIS FILE WAS CHANGED ON - 19 Aug 2022 +# THIS FILE WAS CHANGED ON - 05 Sep 2022 import abc import re @@ -238,7 +238,7 @@ def _needs_column_selection(self): def _get_cursor(self): if self._needs_aggregation(): pipeline = self._make_pipeline() - cur = self.db[self.left_table].aggregate(pipeline) + cur = self.db[self.left_table].aggregate(pipeline, session=self.connection_properties.session) logger.debug(f'Aggregation query: {pipeline}') else: kwargs = {} @@ -257,7 +257,7 @@ def _get_cursor(self): if self.offset: kwargs.update(self.offset.to_mongo()) - cur = self.db[self.left_table].find(**kwargs) + cur = self.db[self.left_table].find(**kwargs, session=self.connection_properties.session) logger.debug(f'Find query: {kwargs}') return cur @@ -330,7 +330,7 @@ def parse(self): def execute(self): db = self.db - self.result = db[self.left_table].update_many(**self.kwargs) + self.result = db[self.left_table].update_many(**self.kwargs , session=self.connection_properties.session) logger.debug(f'update_many: {self.result.modified_count}, matched: {self.result.matched_count}') @@ -388,7 +388,8 @@ def execute(self): } }, {'$inc': {'auto.seq': num}}, - return_document=ReturnDocument.AFTER + return_document=ReturnDocument.AFTER, + session=self.connection_properties.session ) for i, val in enumerate(self._values): @@ -403,7 +404,7 @@ def execute(self): ins[_field] = value docs.append(ins) - res = self.db[self.left_table].insert_many(docs, ordered=False) + res = self.db[self.left_table].insert_many(docs, ordered=False, session=self.connection_properties.session) if auto: self._result_ref.last_row_id = auto['auto']['seq'] else: @@ -479,11 +480,12 @@ def _rename_column(self): '$rename': { self._old_name: self._new_name } - } + }, + session=self.connection_properties.session ) def _rename_collection(self): - self.db[self.left_table].rename(self._new_name) + self.db[self.left_table].rename(self._new_name, session=self.connection_properties.session) def _alter(self, statement: SQLStatement): self.execute = lambda: None @@ -510,7 +512,7 @@ def _alter(self, statement: SQLStatement): print_warn(feature) def _flush(self): - self.db[self.left_table].delete_many({}) + self.db[self.left_table].delete_many({}, session=self.connection_properties.session) def _table(self, statement: SQLStatement): tok = statement.next() @@ -535,7 +537,7 @@ def _drop(self, statement: SQLStatement): raise SQLDecodeError def _drop_index(self): - self.db[self.left_table].drop_index(self._iden_name) + self.db[self.left_table].drop_index(self._iden_name, session=self.connection_properties.session) def _drop_column(self): self.db[self.left_table].update_many( @@ -544,7 +546,8 @@ def _drop_column(self): '$unset': { self._iden_name: '' } - } + }, + session=self.connection_properties.session ) self.db['__schema__'].update_one( {'name': self.left_table}, @@ -552,7 +555,8 @@ def _drop_column(self): '$unset': { f'fields.{self._iden_name}': '' } - } + }, + session=self.connection_properties.session ) def _add(self, statement: SQLStatement): @@ -618,7 +622,8 @@ def _add_column(self): '$set': { self._iden_name: self._default } - } + }, + session=self.connection_properties.session ) self.db['__schema__'].update_one( {'name': self.left_table}, @@ -628,19 +633,24 @@ def _add_column(self): 'type_code': self._type_code } } - } + }, + session=self.connection_properties.session ) def _index(self): self.db[self.left_table].create_index( self.field_dir, - name=self._iden_name) + name=self._iden_name, + session=self.connection_properties.session + ) def _unique(self): self.db[self.left_table].create_index( self.field_dir, unique=True, - name=self._iden_name) + name=self._iden_name, + session=self.connection_properties.session + ) def _fk(self): pass @@ -653,15 +663,15 @@ def __init__(self, *args): def _create_table(self, statement): if '__schema__' not in self.connection_properties.cached_collections: - self.db.create_collection('__schema__') + self.db.create_collection('__schema__', session=self.connection_properties.session) self.connection_properties.cached_collections.add('__schema__') - self.db['__schema__'].create_index('name', unique=True) - self.db['__schema__'].create_index('auto') + self.db['__schema__'].create_index('name', unique=True, session=self.connection_properties.session) + self.db['__schema__'].create_index('auto', session=self.connection_properties.session) tok = statement.next() table = SQLToken.token2sql(tok, self).table try: - self.db.create_collection(table) + self.db.create_collection(table, session=self.connection_properties.session) except CollectionInvalid: if self.connection_properties.enforce_schema: raise @@ -708,10 +718,12 @@ def _create_table(self, statement): _set['auto.seq'] = 0 if SQLColumnDef.primarykey in col.col_constraints: - self.db[table].create_index(field, unique=True, name='__primary_key__') + self.db[table].create_index(field, unique=True, name='__primary_key__', + session=self.connection_properties.session) if SQLColumnDef.unique in col.col_constraints: - self.db[table].create_index(field, unique=True) + self.db[table].create_index(field, unique=True, + session=self.connection_properties.session) if (SQLColumnDef.not_null in col.col_constraints or SQLColumnDef.null in col.col_constraints): @@ -725,7 +737,8 @@ def _create_table(self, statement): self.db['__schema__'].update_one( filter=_filter, update=update, - upsert=True + upsert=True, + session=self.connection_properties.session, ) def parse(self): @@ -762,7 +775,7 @@ def parse(self): def execute(self): db_con = self.db - self.result = db_con[self.left_table].delete_many(**self.kw) + self.result = db_con[self.left_table].delete_many(**self.kw, session=self.connection_properties.session) logger.debug('delete_many: {}'.format(self.result.deleted_count)) def count(self): @@ -976,7 +989,7 @@ def _drop(self, sm): elif tok.match(tokens.Keyword, 'TABLE'): tok = statement.next() table_name = tok.get_name() - self.db.drop_collection(table_name) + self.db.drop_collection(table_name, session=self.connection_properties.session) else: raise SQLDecodeError('statement:{}'.format(sm)) diff --git a/djongo/sql2mongo/sql_tokens.py b/djongo/sql2mongo/sql_tokens.py index 81d92e9f..a10c46e1 100644 --- a/djongo/sql2mongo/sql_tokens.py +++ b/djongo/sql2mongo/sql_tokens.py @@ -130,7 +130,8 @@ def table(self) -> str: name = self.given_table alias2token = self.token_alias.alias2token try: - return alias2token[name].table + if not self.is_explicit_alias(): + return alias2token[name].table except KeyError: return name diff --git a/djongo/transaction.py b/djongo/transaction.py index 0ce03de5..715eedcb 100644 --- a/djongo/transaction.py +++ b/djongo/transaction.py @@ -1,5 +1,25 @@ -from djongo.exceptions import NotSupportedError -from djongo import djongo_access_url +# THIS FILE WAS CHANGED ON - 05 Sep 2022 -print(f'This version of djongo does not support transactions. Visit {djongo_access_url}') -raise NotSupportedError('transactions') +from pymongo import WriteConcern +from pymongo.read_concern import ReadConcern + + +class Transaction: + def __init__(self, mongo_client): + # do initial steps for transaction as noted in mongo documentation for database version 4.0 + self.mongo_client = mongo_client + self.session = self.mongo_client.start_session().__enter__() + self.rollbacked = False + + self.transaction = self.session.start_transaction( + read_concern=ReadConcern("snapshot"), + write_concern=WriteConcern(w="majority")) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, traceback): + # cannot retry bacause documentDB doesn't support retryable writes + # exit transaction decorators as noted in mongo documentation for database version 4.0 + self.transaction.__exit__(exc_type, exc_val, traceback) + self.session.__exit__(exc_type, exc_val, traceback)