Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions morango/sync/stream/deserialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Dict, Generator, List, Optional, Type

from morango.models.certificates import Filter
from morango.models.core import Store, SyncableModel
from morango.registry import syncable_models
from morango.sync.stream.source import MorangoSource, SourceTask


class DeserializeTask(SourceTask):
    """
    Per-record context object handed from stage to stage of the deserialization
    pipeline. It wraps one ``Store`` record and collects the app model and any
    errors produced while processing it.
    """

    # The attributes every task carries
    __slots__ = ("store", "app_model", "fk_cache", "errors")

    def __init__(self, store: Store):
        """
        :param store: The ``Store`` record this task is about
        """
        self.store = store
        # Nothing has been deserialized yet; populated through `set_app_model`
        self.app_model: Optional[SyncableModel] = None
        # Scratch mapping for later stages (name suggests FK lookups -- TODO confirm)
        self.fk_cache: Dict = {}
        # Exceptions recorded through `add_error`
        self.errors: List[Exception] = []

    @property
    def id(self) -> str:
        """The ID of the wrapped store record."""
        return self.store.id

    @property
    def model(self) -> Type[SyncableModel]:
        """The syncable model class registered for the store's profile and model name."""
        store = self.store
        return syncable_models.get_model(store.profile, store.model_name)

    @property
    def has_errors(self) -> bool:
        """Whether at least one error has been recorded on this task."""
        return bool(self.errors)

    def set_app_model(self, app_model: Optional[SyncableModel]) -> None:
        """
        Attach (or clear, when given ``None``) the app model for this task.

        :param app_model: The app model instance, or ``None``
        """
        self.app_model = app_model

    def add_error(self, error: Exception) -> None:
        """
        Record an error against this task.

        :param error: The exception to append to ``errors``
        """
        self.errors.append(error)


class StoreModelSource(MorangoSource[DeserializeTask]):
    """
    Source that wraps ``Store`` records of a profile in ``DeserializeTask`` objects,
    restricted by the optional *sync_filter* and by the dirty/errored flags.
    """

    def __init__(
        self,
        profile: str,
        sync_filter: Optional[Filter] = None,
        dirty_only: bool = True,
        partition_order: str = "asc",
        skip_errored: bool = True,
    ):
        """
        :param profile: The Morango model profile
        :param sync_filter: The Filter object for this sync
        :param dirty_only: Whether to filter on dirty records only
        :param partition_order: Controls how the filter specificity is applied, "asc" or "desc"
        :param skip_errored: Whether to only include records whose ``deserialization_error``
            is an empty string
        """
        super().__init__(
            profile,
            sync_filter=sync_filter,
            dirty_only=dirty_only,
            partition_order=partition_order,
        )
        self.skip_errored = skip_errored

    def stream_for_filter(
        self, partition_condition: Optional[str]
    ) -> Generator[DeserializeTask, None, None]:
        """
        Query the store for one partition prefix and wrap each record in a task.

        :param partition_condition: Partition prefix to match, or ``None`` for no
            partition restriction
        :return: A generator of ``DeserializeTask`` objects
        """
        records = Store.objects.filter(profile=self.profile)

        # Each optional restriction is applied as its own `.filter()` call, in this order
        optional_lookups = (
            (partition_condition is not None, {"partition__startswith": partition_condition}),
            (self.dirty_only, {"dirty_bit": True}),
            (self.skip_errored, {"deserialization_error": ""}),
        )
        for enabled, lookup in optional_lookups:
            if enabled:
                records = records.filter(**lookup)

        yield from map(DeserializeTask, records.iterator())
69 changes: 18 additions & 51 deletions morango/sync/stream/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Generator, Iterable, Iterator, List, Optional, Type

from django.core.serializers.json import DjangoJSONEncoder
from django.db.models import Q

from morango.models.certificates import Filter
from morango.models.core import (
Expand All @@ -16,13 +15,14 @@
SyncableModel,
)
from morango.registry import syncable_models
from morango.sync.stream.core import Buffer, Sink, Source, Transform, Unbuffer
from morango.sync.stream.core import Buffer, Sink, Transform, Unbuffer
from morango.sync.stream.source import MorangoSource, SourceTask
from morango.utils import self_referential_fk

logger = logging.getLogger(__name__)


class SerializeTask(object):
class SerializeTask(SourceTask):
"""Carrier class for providing context through the pipeline"""

__slots__ = ("model", "obj", "store", "counter")
Expand All @@ -33,6 +33,10 @@ def __init__(self, model: Type[SyncableModel], obj: SyncableModel):
self.store: Optional[Store] = None
self.counter: Optional[RecordMaxCounter] = None

@property
def id(self) -> str:
return self.obj.id

@property
def is_store_update(self):
return self.store is not None and not self.store._state.adding
Expand All @@ -52,59 +56,22 @@ def self_referential_fk(self) -> Optional[str]:
return self_referential_fk(self.model)


class AppModelSource(Source[SerializeTask]):
class AppModelSource(MorangoSource[SerializeTask]):
"""
Yields ``SerializeTask`` objects for every syncable-model record that matches the
optional *sync_filter*.
"""

def __init__(
self,
profile: str,
sync_filter: Optional[Filter] = None,
dirty_only: bool = True,
partition_order: str = "asc",
):
"""
:param profile: The Morango model profile
:param sync_filter: The Filter object for this sync
:param dirty_only: Whether to filter on dirty records only
:param partition_order: Controls how the filter specificity is applied, "asc" or "desc"
"""
self.profile = profile
self.sync_filter = sync_filter
self.dirty_only = dirty_only
self.partition_order = partition_order
self._seen = set()

def prefix_conditions(self) -> Generator[Optional[Q], None, None]:
if self.sync_filter is None:
# yield None once, so we do one query without a partition filter (everything)
yield None
else:
partitions_prefixes = [str(prefix) for prefix in self.sync_filter]
partition_iterator = sorted(
partitions_prefixes,
reverse=self.partition_order == "desc",
)

for prefix in partition_iterator:
yield Q(_morango_partition__startswith=prefix)

def stream(self) -> Generator[SerializeTask, None, None]:
for partition_condition in self.prefix_conditions():
for qs in syncable_models.get_model_querysets(self.profile):
if partition_condition is not None:
qs = qs.filter(partition_condition)
if self.dirty_only:
qs = qs.filter(_morango_dirty_bit=True)
for obj in qs.iterator():
# partition filtering could result in overlaps, and since we're walking
# through the partitions one by one, we should avoid duplicates. Morango
# syncable models have unique IDs across the entire profile
if obj.id not in self._seen:
self._seen.add(obj.id)
yield SerializeTask(qs.model, obj)
def stream_for_filter(
self, partition_condition: Optional[str]
) -> Generator[SerializeTask, None, None]:
for qs in syncable_models.get_model_querysets(self.profile):
if partition_condition is not None:
qs = qs.filter(_morango_partition__startswith=partition_condition)
if self.dirty_only:
qs = qs.filter(_morango_dirty_bit=True)
for obj in qs.iterator():
yield SerializeTask(qs.model, obj)


class StoreLookup(Transform[List[SerializeTask]]):
Expand Down
93 changes: 93 additions & 0 deletions morango/sync/stream/source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import abc
from typing import Generator, Iterator, Optional, TypeVar

from morango.models.certificates import Filter
from morango.sync.stream.core import Source


class SourceTask(abc.ABC):
    """
    Typing for source object passed through streaming pipeline.

    Concrete tasks must expose a string ``id``; ``MorangoSource.stream`` relies on it
    to avoid yielding the same record twice.
    """

    # Declaring empty slots here keeps the base class from adding a per-instance
    # ``__dict__``, so that subclasses which declare their own ``__slots__``
    # actually get slotted instances
    __slots__ = ()

    @property
    @abc.abstractmethod
    def id(self) -> str:
        """
        :return: The identifier of the record this task wraps
        """


T = TypeVar("T", bound=SourceTask)


class MorangoSource(Source[T], abc.ABC):
    """
    Shared base for Morango sources (e.g. syncable models and Store records): walks the
    sync filter's partition prefixes and de-duplicates what the subclass yields for each.
    """

    def __init__(
        self,
        profile: str,
        sync_filter: Optional[Filter] = None,
        dirty_only: bool = True,
        partition_order: str = "asc",
    ):
        """
        :param profile: The Morango model profile
        :param sync_filter: The Filter object for this sync
        :param dirty_only: Whether to filter on dirty records only
        :param partition_order: Controls how the filter specificity is applied, "asc" or "desc"
        """
        self.profile = profile
        self.sync_filter = sync_filter
        self.dirty_only = dirty_only
        self.partition_order = partition_order
        # IDs already yielded by `stream`; kept on the instance
        self._seen = set()

    def prefix_conditions(self) -> Generator[Optional[str], None, None]:
        """
        Produce the partition prefixes to query, one at a time.

        Without a sync filter, a single ``None`` is produced, meaning "query without a
        partition restriction". Otherwise the filter's prefixes are produced as strings
        in sorted order, reversed when ``partition_order`` is "desc".

        :return: A generator yielding partition prefixes, or ``None`` once
        """
        if self.sync_filter is None:
            # one pass with no partition restriction at all (everything)
            yield None
            return

        descending = self.partition_order == "desc"
        yield from sorted(map(str, self.sync_filter), reverse=descending)

    def stream(self) -> Generator[T, None, None]:
        """
        Stream tasks for every partition prefix, passing each prefix to
        `stream_for_filter` and yielding only tasks whose ``id`` has not been seen before.

        :return: A generator yielding tasks with unique IDs
        """
        for prefix in self.prefix_conditions():
            for task in self.stream_for_filter(prefix):
                task_id = task.id
                # Prefixes may overlap, so a record can come back under more than one of
                # them; IDs are unique across the whole profile, which makes them a safe
                # de-duplication key
                if task_id in self._seen:
                    continue
                self._seen.add(task_id)
                yield task

    @abc.abstractmethod
    def stream_for_filter(self, partition_condition: Optional[str]) -> Iterator[T]:
        """
        Subclass hook: produce the tasks matching a single partition prefix.

        :param partition_condition: A partition prefix to restrict by, or ``None`` for
            no partition restriction
        :return: An iterator yielding tasks
        """
124 changes: 124 additions & 0 deletions tests/testapp/tests/sync/stream/test_deserialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import mock
from django.test import SimpleTestCase

from morango.models.certificates import Filter
from morango.models.core import Store, SyncableModel
from morango.sync.stream.deserialize import DeserializeTask, StoreModelSource


class DeserializeTaskTestCase(SimpleTestCase):
    """Unit tests for ``DeserializeTask`` built around a mocked ``Store`` record."""

    def setUp(self):
        # spec_set keeps the mock limited to attributes that exist on Store
        store = mock.Mock(spec_set=Store)
        store.profile = "test"
        store.model_name = "testmodel"
        self.store = store
        self.task = DeserializeTask(store)

    @mock.patch("morango.sync.stream.deserialize.syncable_models.get_model")
    def test_model(self, mock_get_model):
        """`model` is looked up in the registry by the store's profile and model name."""
        expected_model = mock.Mock(spec_set=SyncableModel)
        mock_get_model.return_value = expected_model

        self.assertEqual(self.task.model, expected_model)
        mock_get_model.assert_called_once_with("test", "testmodel")

    def test_has_errors(self):
        """`has_errors` flips to true once an error has been added."""
        task = self.task
        self.assertFalse(task.has_errors)
        task.add_error(ValueError("bad data"))
        self.assertTrue(task.has_errors)

    def test_set_app_model(self):
        """`set_app_model` stores the given instance on the task."""
        instance = mock.Mock(spec_set=SyncableModel)
        self.task.set_app_model(instance)
        self.assertEqual(self.task.app_model, instance)


class StoreModelSourceTestCase(SimpleTestCase):
    """Unit tests for ``StoreModelSource`` with the ``Store`` manager's filter patched out."""

    @staticmethod
    def _chainable_queryset(records):
        """Build a mock queryset whose ``filter`` returns itself and which iterates *records*."""
        queryset = mock.Mock()
        queryset.filter.return_value = queryset
        queryset.iterator.return_value = records
        return queryset

    @staticmethod
    def _store_with_id(store_id):
        """Build a mocked ``Store`` record carrying the given ID."""
        record = mock.Mock(spec_set=Store)
        record.id = store_id
        return record

    def test_prefix_conditions__none(self):
        """Without a sync filter, a single ``None`` condition is produced."""
        source = StoreModelSource(profile="test")
        self.assertEqual(list(source.prefix_conditions()), [None])

    def test_prefix_conditions__with_filter_asc(self):
        """Ascending order sorts the filter's prefixes."""
        source = StoreModelSource(
            profile="test",
            sync_filter=Filter("b\na"),
            partition_order="asc",
        )
        self.assertEqual(list(source.prefix_conditions()), ["a", "b"])

    def test_prefix_conditions__with_filter_desc(self):
        """Descending order reverses the sorted prefixes."""
        source = StoreModelSource(
            profile="test",
            sync_filter=Filter("a\nb"),
            partition_order="desc",
        )
        self.assertEqual(list(source.prefix_conditions()), ["b", "a"])

    @mock.patch("morango.sync.stream.deserialize.Store.objects.filter")
    def test_stream__no_partition(self, mock_store_filter):
        """With defaults, only the dirty and errored filters are applied, in that order."""
        store = self._store_with_id("123")
        qs = self._chainable_queryset([store])
        mock_store_filter.return_value = qs

        tasks = list(StoreModelSource(profile="test").stream())

        self.assertEqual(len(tasks), 1)
        self.assertEqual(tasks[0].store, store)
        mock_store_filter.assert_called_once_with(profile="test")
        expected_filters = [
            mock.call(dirty_bit=True),
            mock.call(deserialization_error=""),
        ]
        self.assertEqual(qs.filter.call_args_list, expected_filters)

    @mock.patch("morango.sync.stream.deserialize.Store.objects.filter")
    def test_stream__partition_order_and_seen_once(self, mock_store_filter):
        """Overlapping partitions are queried in order and a repeated record is yielded once."""
        first, second, third = (self._store_with_id(pk) for pk in ("1", "2", "3"))
        # record "2" is returned under both prefixes
        qs_a = self._chainable_queryset([first, second])
        qs_ab = self._chainable_queryset([second, third])
        mock_store_filter.side_effect = [qs_a, qs_ab]

        source = StoreModelSource(profile="test", sync_filter=Filter("a\na/b"))
        yielded_ids = [task.store.id for task in source.stream()]

        self.assertEqual(yielded_ids, ["1", "2", "3"])
        self.assertEqual(
            mock_store_filter.call_args_list,
            [mock.call(profile="test"), mock.call(profile="test")],
        )
        # the partition restriction is the first filter applied to each queryset
        self.assertEqual(
            qs_a.filter.call_args_list[0],
            mock.call(partition__startswith="a"),
        )
        self.assertEqual(
            qs_ab.filter.call_args_list[0],
            mock.call(partition__startswith="a/b"),
        )

    @mock.patch("morango.sync.stream.deserialize.Store.objects.filter")
    def test_stream__no_dirty_filter(self, mock_store_filter):
        """With ``dirty_only=False`` the dirty-bit filter is left out."""
        qs = self._chainable_queryset([])
        mock_store_filter.return_value = qs

        source = StoreModelSource(profile="test", dirty_only=False)
        list(source.stream())

        self.assertEqual(qs.filter.call_args_list, [mock.call(deserialization_error="")])
Loading
Loading