Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions djongo/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
"""
MongoDB database backend for Django
"""

# THIS FILE WAS CHANGED ON - 05 Sep 2022

import traceback
from collections import OrderedDict
from logging import getLogger
from django.db.backends.base.base import BaseDatabaseWrapper
Expand All @@ -10,10 +14,12 @@
from .creation import DatabaseCreation
from . import database as Database
from .cursor import Cursor
from .database import DatabaseError
from .features import DatabaseFeatures
from .introspection import DatabaseIntrospection
from .operations import DatabaseOperations
from .schema import DatabaseSchemaEditor
from .transaction import Transaction

logger = getLogger(__name__)

Expand All @@ -34,9 +40,10 @@ def __contains__(self, item):

class DjongoClient:

def __init__(self, database, enforce_schema=True):
def __init__(self, database, enforce_schema=True, session=None):
self.enforce_schema = enforce_schema
self.cached_collections = CachedCollections(database)
self.session = session


class DatabaseWrapper(BaseDatabaseWrapper):
Expand Down Expand Up @@ -115,6 +122,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def __init__(self, *args, **kwargs):
self.client_connection = None
self.djongo_connection = None
self.transaction = None
self.rollbacked = False
super().__init__(*args, **kwargs)

def is_usable(self):
Expand Down Expand Up @@ -187,7 +196,14 @@ def _set_autocommit(self, autocommit):
TODO: For future reference, setting two phase commits and rollbacks
might require populating this method.
"""
pass
self.autocommit = False

def set_autocommit(
self, autocommit, force_begin_transaction_with_broken_autocommit=False
):
result = super().set_autocommit(autocommit, force_begin_transaction_with_broken_autocommit=False)
self.autocommit = False
return result

def init_connection_state(self):
try:
Expand Down Expand Up @@ -220,3 +236,34 @@ def _commit(self):
TODO: two phase commits are not supported yet.
"""
pass

def _savepoint(self, sid):
# add _savepoint method to work with Django's transactions
self.in_atomic_block = True
self.transaction = Transaction(self.client_connection)
connection_params = self.get_connection_params()

name = connection_params.pop('name')

# this will be used in sql2mongo/query.py as session parameter when using pymongo CRUD operations
self.djongo_connection.session = self.transaction.session
# this will be used in models/fields.py as session parameter when using pymongo CRUD operations
# in that file pymongo functions are prefixed with 'mongo_'
self.client_connection[name].__setattr__('session', self.transaction.session)


def _savepoint_commit(self, sid):
# add _savepoint_commit method to work with Django's transactions
# this method is executed even if rollback is executed after it
connection_params = self.get_connection_params()
name = connection_params.pop('name')
if not self.rollbacked:
self.transaction.__exit__(None, None, traceback)
self.djongo_connection = DjongoClient(self.connection, self.get_connection_params().get('enforce_schema'))
self.client_connection[name].__setattr__('session', None)

def _savepoint_rollback(self, sid):
self.rollbacked = True
# We have to pass in some error, but it is not used anywhere as far as known
self.transaction.__exit__('DatabaseError', DatabaseError('Error in transaction; rollbacked'), traceback)

4 changes: 3 additions & 1 deletion djongo/cursor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# THIS FILE WAS CHANGED ON - 05 Sep 2022

from logging import getLogger

from .database import DatabaseError
Expand Down Expand Up @@ -55,7 +57,7 @@ def execute(self, sql, params=None):
sql,
params)
except Exception as e:
db_exe = DatabaseError()
db_exe = DatabaseError(str(e))
raise db_exe from e

def fetchmany(self, size=1):
Expand Down
4 changes: 3 additions & 1 deletion djongo/features.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# THIS FILE WAS CHANGED ON - 05 Sep 2022

from django.db.backends.base.features import BaseDatabaseFeatures


Expand All @@ -7,7 +9,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
has_bulk_insert = True
has_native_uuid_field = True
supports_timezones = False
uses_savepoints = False
uses_savepoints = True
can_clone_databases = True
test_db_allows_multiple_connections = False
supports_unspecified_pk = True
Expand Down
14 changes: 9 additions & 5 deletions djongo/models/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
These are the main fields for working with MongoDB.
"""

# THIS FILE WAS CHANGED ON - 28 Mar 2022
# THIS FILE WAS CHANGED ON - 05 Sep 2022

import functools
import json
Expand Down Expand Up @@ -781,7 +781,8 @@ def add(self, *objs):
lh_field.get_attname():
getattr(self.instance, rh_field.get_attname())
}
}
},
session=pymongo_connections[self.db].djongo_connection.session
)
for obj in objs:
fk_field = getattr(obj, lh_field.get_attname())
Expand Down Expand Up @@ -877,7 +878,8 @@ def add(self, *objs):
'$each': list(new_fks)
}
}
}
},
session=pymongo_connections[self.db].djongo_connection.session
)

add.alters_data = True
Expand All @@ -903,7 +905,8 @@ def _remove(self, to_del):
'$in': list(to_del)
}
}
}
},
session=pymongo_connections[self.db].djongo_connection.session
)

def clear(self):
Expand All @@ -914,7 +917,8 @@ def clear(self):
'$set': {
self.field.attname: []
}
}
},
session=pymongo_connections[self.db].djongo_connection.session
)
setattr(self.instance, self.field.attname, set())

Expand Down
67 changes: 63 additions & 4 deletions djongo/sql2mongo/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def __init__(
self.is_negated = False
self._name = name
self.precedence = OPERATOR_PRECEDENCE[name]

self._options = {}

def negate(self):
raise NotImplementedError
Expand Down Expand Up @@ -508,6 +510,22 @@ def to_mongo(self):
return self._op.to_mongo()


class CaseOp(_Op, _StatementParser):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.statement.skip(2)
self._statement2ops()
self.evaluate()

def negate(self):
raise NotImplementedError

def to_mongo(self):
return self._op.to_mongo()


class ParenthesisOp(_Op, _StatementParser):

def to_mongo(self):
Expand All @@ -525,14 +543,16 @@ def negate(self):

class CmpOp(_Op):


def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._identifier = SQLToken.token2sql(self.statement.left, self.query)

if isinstance(self.statement.right, Identifier):
raise SQLDecodeError('Join using WHERE not supported')

self._operator = OPERATOR_MAP[self.statement.token_next(0)[1].value]
sql_operation = self.statement.token_next(0)[1].value
self._operator = OPERATOR_MAP[sql_operation]
index = re_index(self.statement.right.value)

if self._operator in NEW_OPERATORS:
Expand All @@ -541,6 +561,12 @@ def __init__(self, *args, **kwargs):
else:
self._constant = self.params[index] if index is not None else MAP_INDEX_NONE[self.statement.right.value]

if sql_operation == 'iLIKE':
self._options = {'$options': 'im'}
self._constant = self._make_regex(self._constant)
if sql_operation == 'LIKE':
self._constant = self._make_regex(self._constant)

if isinstance(self._constant, dict):
self._field_ext, self._constant = next(iter(self._constant.items()))
else:
Expand All @@ -558,9 +584,40 @@ def to_mongo(self):
field += '.' + self._field_ext

if not self.is_negated:
return {field: {self._operator: self._constant}}
return {field: {self._operator: self._constant, **self._options}}
else:
return {field: {'$not': {self._operator: self._constant, **self._options}}}

@staticmethod
def _check_embedded(to_match):
try:
check_dict = to_match
replace_chars = "\\%'"
for c in replace_chars:
if c == "'":
check_dict = check_dict.replace("'", '"')
else:
check_dict = check_dict.replace(c, "")
check_dict = json.loads(check_dict)
if isinstance(check_dict, dict):
return check_dict
else:
return to_match
except Exception as e:
return to_match

@staticmethod
def _make_regex(to_match):
to_match = CmpOp._check_embedded(to_match)
if isinstance(to_match, str):
to_match = to_match.replace('%', '.*')
regex = '^' + to_match + '$'
elif isinstance(to_match, dict):
field_ext, to_match = next(iter(to_match.items()))
regex = to_match
else:
return {field: {'$not': {self._operator: self._constant}}}
raise SQLDecodeError
return regex


OPERATOR_MAP = {
Expand All @@ -570,7 +627,9 @@ def to_mongo(self):
'>=': '$gte',
'<=': '$lte',
'IN': '$in',
'NOT IN': '$nin'
'NOT IN': '$nin',
'LIKE': '$regex',
'iLIKE': '$regex',
}
OPERATOR_PRECEDENCE = {
'IS': 8,
Expand Down
Loading