diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 4d1ddb5..a83f11e 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -9,7 +9,7 @@ on:
       - main

 jobs:
-  build:
+  test-sqlite:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
@@ -33,3 +33,46 @@ jobs:
       - name: Test with pytest
         run: |
           pytest --color=yes
+
+  test-mysql:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [3.9, "3.10", "3.11", "3.12"]
+
+    services:
+      mysql:
+        image: mysql:8.0
+        env:
+          MYSQL_ROOT_PASSWORD: root
+          MYSQL_DATABASE: ingestify_test
+        ports:
+          - 3306:3306
+        options: >-
+          --health-cmd="mysqladmin ping"
+          --health-interval=10s
+          --health-timeout=5s
+          --health-retries=3
+
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -e ".[test]"
+      - name: Install MySQL dependencies
+        run: |
+          pip install mysqlclient
+      - name: Code formatting
+        run: |
+          pip install black==22.3.0
+          black --check .
+      - name: Test with pytest
+        env:
+          INGESTIFY_TEST_DATABASE_URL: mysql://root:root@127.0.0.1:3306/ingestify_test
+        run: |
+          pytest --color=yes
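Note: the INGESTIFY_TEST_DATABASE_URL exported in the last step is a standard SQLAlchemy URL; the bare mysql:// scheme resolves to the mysqlclient (MySQLdb) driver installed above. A quick connectivity sketch against the service container (a hypothetical check, not part of the suite):

import os
from sqlalchemy import create_engine, text

# Falls back to the workflow's service-container URL when run outside CI
url = os.environ.get(
    "INGESTIFY_TEST_DATABASE_URL",
    "mysql://root:root@127.0.0.1:3306/ingestify_test",
)
engine = create_engine(url)
with engine.connect() as conn:
    assert conn.execute(text("SELECT 1")).scalar() == 1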
Useful for test cleanup.""" + if hasattr(self, "metadata") and hasattr(self, "engine"): + self.metadata.drop_all(self.engine) + def get(self): return self.session() @@ -208,18 +216,33 @@ def _upsert( primary_key_columns = [column for column in table.columns if column.primary_key] - if immutable_rows: - stmt = stmt.on_conflict_do_nothing(index_elements=primary_key_columns) + if dialect == "mysql": + # MySQL uses ON DUPLICATE KEY UPDATE syntax + if immutable_rows: + # For MySQL immutable rows, use INSERT IGNORE to skip duplicates + stmt = stmt.prefix_with("IGNORE") + else: + # MySQL uses stmt.inserted instead of stmt.excluded + set_ = { + name: stmt.inserted[name] + for name, column in table.columns.items() + if column not in primary_key_columns + } + stmt = stmt.on_duplicate_key_update(set_) else: - set_ = { - name: getattr(stmt.excluded, name) - for name, column in table.columns.items() - if column not in primary_key_columns - } - - stmt = stmt.on_conflict_do_update( - index_elements=primary_key_columns, set_=set_ - ) + # PostgreSQL and SQLite use ON CONFLICT syntax + if immutable_rows: + stmt = stmt.on_conflict_do_nothing(index_elements=primary_key_columns) + else: + set_ = { + name: getattr(stmt.excluded, name) + for name, column in table.columns.items() + if column not in primary_key_columns + } + + stmt = stmt.on_conflict_do_update( + index_elements=primary_key_columns, set_=set_ + ) connection.execute(stmt) @@ -242,7 +265,8 @@ def _build_cte_sqlite(self, records, name: str) -> CTE: def _build_cte(self, records: list[dict], name: str) -> CTE: """Build a CTE from a list of dictionaries.""" - if self.dialect.name == "sqlite": + if self.dialect.name in ("sqlite", "mysql"): + # SQLite and MySQL don't support VALUES syntax, use UNION ALL instead return self._build_cte_sqlite(records, name) first_row = records[0] diff --git a/ingestify/infra/store/dataset/sqlalchemy/tables.py b/ingestify/infra/store/dataset/sqlalchemy/tables.py index c164de9..fc9254b 100644 --- a/ingestify/infra/store/dataset/sqlalchemy/tables.py +++ b/ingestify/infra/store/dataset/sqlalchemy/tables.py @@ -51,14 +51,31 @@ class TZDateTime(TypeDecorator): LOCAL_TIMEZONE = datetime.datetime.utcnow().astimezone().tzinfo cache_ok = True + def __init__(self, fsp=None, **kwargs): + super().__init__(**kwargs) + self.fsp = fsp + + def load_dialect_impl(self, dialect): + # For MySQL, use DATETIME with fractional seconds precision + if dialect.name == "mysql" and self.fsp is not None: + from sqlalchemy.dialects.mysql import DATETIME as MySQL_DATETIME + + # Return the type without type_descriptor to ensure our process methods are called + return MySQL_DATETIME(fsp=self.fsp) + return super().load_dialect_impl(dialect) + def process_bind_param(self, value: Optional[datetime.datetime], dialect): if not value: return None if value.tzinfo is None: - value = value.astimezone(self.LOCAL_TIMEZONE) + # Assume naive datetimes are already in UTC + value = value.replace(tzinfo=datetime.timezone.utc) + else: + # Convert timezone-aware datetimes to UTC + value = value.astimezone(datetime.timezone.utc) - return value.astimezone(datetime.timezone.utc) + return value def process_result_value(self, value, dialect): if not value: diff --git a/ingestify/tests/config.yaml b/ingestify/tests/config.yaml index 175032b..1318297 100644 --- a/ingestify/tests/config.yaml +++ b/ingestify/tests/config.yaml @@ -1,6 +1,9 @@ main: - # Cannot use in memory data because database is shared between processes - metadata_url: !ENV "sqlite:///${TEST_DIR}/main.db" + # For 
diff --git a/ingestify/infra/store/dataset/sqlalchemy/tables.py b/ingestify/infra/store/dataset/sqlalchemy/tables.py
index c164de9..fc9254b 100644
--- a/ingestify/infra/store/dataset/sqlalchemy/tables.py
+++ b/ingestify/infra/store/dataset/sqlalchemy/tables.py
@@ -51,14 +51,31 @@ class TZDateTime(TypeDecorator):
     LOCAL_TIMEZONE = datetime.datetime.utcnow().astimezone().tzinfo
     cache_ok = True

+    def __init__(self, fsp=None, **kwargs):
+        super().__init__(**kwargs)
+        self.fsp = fsp
+
+    def load_dialect_impl(self, dialect):
+        # For MySQL, use DATETIME with fractional seconds precision
+        if dialect.name == "mysql" and self.fsp is not None:
+            from sqlalchemy.dialects.mysql import DATETIME as MySQL_DATETIME
+
+            # Return the dialect type directly (not via type_descriptor) so our process_* methods still run
+            return MySQL_DATETIME(fsp=self.fsp)
+        return super().load_dialect_impl(dialect)
+
     def process_bind_param(self, value: Optional[datetime.datetime], dialect):
         if not value:
             return None

         if value.tzinfo is None:
-            value = value.astimezone(self.LOCAL_TIMEZONE)
+            # Assume naive datetimes are already in UTC
+            value = value.replace(tzinfo=datetime.timezone.utc)
+        else:
+            # Convert timezone-aware datetimes to UTC
+            value = value.astimezone(datetime.timezone.utc)

-        return value.astimezone(datetime.timezone.utc)
+        return value

     def process_result_value(self, value, dialect):
         if not value:
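Note: the net effect of the process_bind_param change is that everything is stored as UTC, with naive values trusted to be UTC rather than interpreted as local time. A self-contained sketch of that rule:

import datetime

def normalize_to_utc(value: datetime.datetime) -> datetime.datetime:
    if value.tzinfo is None:
        # Naive datetimes are assumed to already be in UTC
        return value.replace(tzinfo=datetime.timezone.utc)
    # Aware datetimes are converted to UTC
    return value.astimezone(datetime.timezone.utc)

naive = datetime.datetime(2024, 1, 1, 12, 0)
aware = datetime.datetime(
    2024, 1, 1, 13, 0, tzinfo=datetime.timezone(datetime.timedelta(hours=1))
)
# Both represent 12:00 UTC under the new rule
assert normalize_to_utc(naive) == normalize_to_utc(aware)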
diff --git a/ingestify/tests/config.yaml b/ingestify/tests/config.yaml
index 175032b..1318297 100644
--- a/ingestify/tests/config.yaml
+++ b/ingestify/tests/config.yaml
@@ -1,6 +1,9 @@
 main:
-  # Cannot use in memory data because database is shared between processes
-  metadata_url: !ENV "sqlite:///${TEST_DIR}/main.db"
+  # For MySQL tests, INGESTIFY_TEST_DATABASE_URL will be set;
+  # for SQLite tests, it falls back to a file in TEST_DIR
+  metadata_url: !ENV ${INGESTIFY_TEST_DATABASE_URL}
+  metadata_options:
+    table_prefix: !ENV ${INGESTIFY_TEST_DATABASE_PREFIX}_
   file_url: !ENV file://${TEST_DIR}/data
   default_bucket: main
diff --git a/ingestify/tests/conftest.py b/ingestify/tests/conftest.py
index 33faeae..d87cf56 100644
--- a/ingestify/tests/conftest.py
+++ b/ingestify/tests/conftest.py
@@ -1,8 +1,11 @@
 import tempfile
-import pytest
 import os

+import pytest
+
+from ingestify.main import get_engine
+

 @pytest.fixture(scope="function", autouse=True)
 def datastore_dir():
@@ -12,6 +15,45 @@ def datastore_dir():
         yield tmpdirname


-@pytest.fixture(scope="session")
-def config_file():
+@pytest.fixture(scope="function")
+def ingestify_test_database_url(datastore_dir, monkeypatch):
+    key = "INGESTIFY_TEST_DATABASE_URL"
+
+    value = os.environ.get(key)
+    if value is None:
+        value = f"sqlite:///{datastore_dir}/main.db"
+    monkeypatch.setenv(key, value)
+
+    yield value
+
+
+@pytest.fixture(scope="function")
+def config_file(ingestify_test_database_url):
+    # Depend on ingestify_test_database_url so the environment variables are set
+    # in time and the database is cleaned before ingestify opens a connection
     return os.path.abspath(os.path.dirname(__file__) + "/config.yaml")
+
+
+@pytest.fixture
+def db_cleanup():
+    def do_cleanup(engine):
+        # Close connections after the test, then drop all tables
+        session_provider = getattr(
+            engine.store.dataset_repository, "session_provider", None
+        )
+        if session_provider:
+            session_provider.session.remove()
+            session_provider.engine.dispose()
+            session_provider.drop_all_tables()
+
+    return do_cleanup
+
+
+@pytest.fixture(scope="function")
+def engine(config_file, db_cleanup):
+    # Create the engine for the test
+    engine = get_engine(config_file, "main")
+
+    yield engine
+
+    db_cleanup(engine)
diff --git a/ingestify/tests/test_auto_ingest.py b/ingestify/tests/test_auto_ingest.py
index 0073b9a..83bcdf0 100644
--- a/ingestify/tests/test_auto_ingest.py
+++ b/ingestify/tests/test_auto_ingest.py
@@ -8,6 +8,7 @@
 from ingestify.domain.models.fetch_policy import FetchPolicy
 from ingestify.domain import Selector, DataSpecVersionCollection
 from ingestify import Source, DatasetResource
+from ingestify.utils import utcnow


 class MockSource(Source):
@@ -39,7 +40,7 @@ def find_datasets(
                 url="http://test.com/match1",
             ).add_file(
                 data_feed_key="test",
-                last_modified=datetime.datetime.now(),
+                last_modified=utcnow(),
                 json_content={"blaat": "piet"},
             )

@@ -75,7 +76,7 @@ def find_datasets(
                     url="http://test.com/match1",
                 ).add_file(
                     data_feed_key="test",
-                    last_modified=datetime.datetime.now(),
+                    last_modified=utcnow(),
                     json_content={"competition_id": 11},
                 )
             elif competition_id == 22:
@@ -91,7 +92,7 @@ def find_datasets(
                     url="http://test.com/match2",
                 ).add_file(
                     data_feed_key="test",
-                    last_modified=datetime.datetime.now(),
+                    last_modified=utcnow(),
                     json_content={"competition_id": 22},
                 )

@@ -106,10 +107,8 @@ def discover_selectors(self, dataset_type: str):
         ]


-def test_iter_datasets_basic_auto_ingest(config_file):
+def test_iter_datasets_basic_auto_ingest(engine):
     """Test basic auto-ingest functionality."""
-    engine = get_engine(config_file)
-
     # Add a simple ingestion plan
     mock_source = MockSource(name="test_source")
     data_spec_versions = DataSpecVersionCollection.from_dict({"default": {"v1"}})
@@ -141,20 +140,16 @@
     assert datasets[0].identifier["competition_id"] == 11


-def test_iter_datasets_auto_ingest_disabled(config_file):
+def test_iter_datasets_auto_ingest_disabled(engine):
     """Test that auto_ingest=False returns only existing datasets."""
-    engine = get_engine(config_file)
-
     # Should only return existing datasets (none in empty store)
     datasets = list(engine.iter_datasets(competition_id=11, auto_ingest=False))

     assert len(datasets) == 0


-def test_iter_datasets_outside_config_scope(config_file):
+def test_iter_datasets_outside_config_scope(engine):
     """Test that requests outside IngestionPlan scope return nothing."""
-    engine = get_engine(config_file)
-
     # Add plan only for competition_id=11
     mock_source = MockSource(name="test_source")
     data_spec_versions = DataSpecVersionCollection.from_dict({"default": {"v1"}})
@@ -180,10 +175,8 @@
     assert len(datasets) == 0


-def test_iter_datasets_discover_selectors_with_filters(config_file):
+def test_iter_datasets_discover_selectors_with_filters(engine):
     """Test that selector_filters are applied after discover_selectors runs."""
-    engine = get_engine(config_file)
-
     # Create an IngestionPlan with empty selector - this will trigger discover_selectors
     mock_source = MockSourceWithDiscoverSelectors(name="test_source_discover")
     data_spec_versions = DataSpecVersionCollection.from_dict({"default": {"v1"}})
@@ -216,10 +209,8 @@
     assert datasets[0].name == "Mock match comp 11"


-def test_iter_datasets_discover_selectors_multiple_matches(config_file):
+def test_iter_datasets_discover_selectors_multiple_matches(engine):
     """Test that multiple discovered selectors can match the filters."""
-    engine = get_engine(config_file)
-
     # Create an IngestionPlan with empty selector - this will trigger discover_selectors
     mock_source = MockSourceWithDiscoverSelectors(name="test_source_discover")
     data_spec_versions = DataSpecVersionCollection.from_dict({"default": {"v1"}})
@@ -248,12 +239,10 @@
     assert competition_ids == {11, 22}


-def test_selector_filters_make_discovered_selectors_more_strict(config_file):
+def test_selector_filters_make_discovered_selectors_more_strict(engine):
     """Test that when selector_filters are more strict than discovered selectors, we make the selectors more strict."""
     from unittest.mock import Mock

-    engine = get_engine(config_file)
-
     # Create a source that returns multiple matches per season
     class MockSourceMultipleMatches(Source):
         @property
@@ -291,7 +280,7 @@ def find_datasets(
                         url=f"http://test.com/match{mid}",
                     ).add_file(
                         data_feed_key="test",
-                        last_modified=datetime.datetime.now(),
+                        last_modified=utcnow(),
                         json_content={"match_id": mid},
                     )
             return []
@@ -348,13 +337,11 @@ def discover_selectors(self, dataset_type):
     # Without this optimization, we'd call with match_id=None and fetch 3 matches instead of 1


-def test_iter_datasets_with_open_data_auto_discovery(config_file):
+def test_iter_datasets_with_open_data_auto_discovery(engine):
     """Test that use_open_data=True auto-discovers open data sources without configuration."""
     from unittest.mock import Mock
     from ingestify.application import loader

-    engine = get_engine(config_file)
-
     # Create mock source class that inherits from Source
     class MockOpenDataSource(Source):
         def __init__(self, name):
@@ -387,7 +374,7 @@ def find_datasets(
                 url="http://open-data.com/match123",
             ).add_file(
                 data_feed_key="test",
-                last_modified=datetime.datetime.now(),
+                last_modified=utcnow(),
                 json_content={"match_id": 123},
             )
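Note: with the conftest.py fixtures above, a test only has to depend on engine to get a fully configured IngestionEngine (SQLite locally, MySQL when INGESTIFY_TEST_DATABASE_URL is set), with tables dropped afterwards. A hypothetical minimal test:

def test_store_starts_empty(engine):
    # The engine fixture already set up the database and will tear it down
    datasets = engine.store.get_dataset_collection()
    assert len(datasets) == 0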
json_content={"match_id": 123}, ) diff --git a/ingestify/tests/test_engine.py b/ingestify/tests/test_engine.py index a7ec4d1..aab0ee0 100644 --- a/ingestify/tests/test_engine.py +++ b/ingestify/tests/test_engine.py @@ -25,7 +25,8 @@ from ingestify.domain.models.fetch_policy import FetchPolicy from ingestify.domain.models.task.task_summary import TaskState from ingestify.infra.serialization import serialize, deserialize -from ingestify.main import get_engine, get_dev_engine +from ingestify.main import get_dev_engine +from ingestify.utils import utcnow def add_ingestion_plan(engine: IngestionEngine, source: Source, **selector): @@ -78,7 +79,7 @@ def find_datasets( season_id, **kwargs, ): - last_modified = datetime.now(pytz.utc) + last_modified = utcnow() yield ( DatasetResource( @@ -273,9 +274,7 @@ def find_datasets( ) -def test_engine(config_file): - engine = get_engine(config_file, "main") - +def test_engine(engine): add_ingestion_plan( engine, SimpleFakeSource("fake-source"), competition_id=1, season_id=2 ) @@ -293,6 +292,7 @@ def test_engine(config_file): dataset = datasets.first() assert dataset.identifier == Identifier(competition_id=1, season_id=2, match_id=1) + assert len(dataset.revisions) == 2 assert len(dataset.revisions[0].modified_files) == 3 assert len(dataset.revisions[1].modified_files) == 1 @@ -325,13 +325,11 @@ def test_engine(config_file): assert dataset.last_modified_at is not None -def test_iterator_source(config_file): +def test_iterator_source(engine): """Test when a Source returns a Iterator to do Batch processing. Every batch must be executed right away. """ - engine = get_engine(config_file, "main") - batch_source = None def callback(idx): @@ -339,7 +337,7 @@ def callback(idx): datasets = engine.store.get_dataset_collection() assert len(datasets) == idx - if idx == 1000: + if idx == 100: batch_source.should_stop = True batch_source = BatchSource("fake-source", callback) @@ -348,7 +346,7 @@ def callback(idx): engine.load() datasets = engine.store.get_dataset_collection() - assert len(datasets) == 1000 + assert len(datasets) == 100 for dataset in datasets: assert len(dataset.revisions) == 1 @@ -357,14 +355,14 @@ def callback(idx): batch_source.should_stop = False def callback(idx): - if idx == 1000: + if idx == 100: batch_source.should_stop = True batch_source.callback = callback engine.load() datasets = engine.store.get_dataset_collection() - assert len(datasets) == 1000 + assert len(datasets) == 100 for dataset in datasets: assert len(dataset.revisions) == 2 @@ -373,9 +371,7 @@ def callback(idx): deserialize(s) -def test_ingestion_plan_failing_task(config_file): - engine = get_engine(config_file, "main") - +def test_ingestion_plan_failing_task(engine): source = FailingLoadSource("fake-source") add_ingestion_plan(engine, source, competition_id=1, season_id=2) @@ -387,9 +383,7 @@ def test_ingestion_plan_failing_task(config_file): assert items[0].task_summaries[0].state == TaskState.FAILED -def test_ingestion_plan_failing_job(config_file): - engine = get_engine(config_file, "main") - +def test_ingestion_plan_failing_job(engine): source = FailingJobSource("fake-source") add_ingestion_plan(engine, source, competition_id=1, season_id=2) @@ -412,9 +406,7 @@ def test_change_partition_key_transformer(): """ -def test_serde(config_file): - engine = get_engine(config_file, "main") - +def test_serde(engine): add_ingestion_plan( engine, SimpleFakeSource("fake-source"), competition_id=1, season_id=2 ) @@ -434,10 +426,8 @@ def test_serde(config_file): assert 
diff --git a/ingestify/tests/test_file_cache.py b/ingestify/tests/test_file_cache.py
index 3312a8e..cd7e560 100644
--- a/ingestify/tests/test_file_cache.py
+++ b/ingestify/tests/test_file_cache.py
@@ -8,10 +8,9 @@
 from ingestify.domain.models.dataset.revision import RevisionSource, SourceType


-def test_file_cache(config_file):
+def test_file_cache(engine):
     """Test file caching with the with_file_cache context manager."""
     # Get engine from the fixture
-    engine = get_engine(config_file, "main")
     store = engine.store

     # Create a timestamp for test data
diff --git a/ingestify/tests/test_pagination.py b/ingestify/tests/test_pagination.py
index 075e4f1..81813a2 100644
--- a/ingestify/tests/test_pagination.py
+++ b/ingestify/tests/test_pagination.py
@@ -6,10 +6,9 @@
 from ingestify.main import get_engine


-def test_iter_dataset_collection_batches(config_file):
+def test_iter_dataset_collection_batches(engine):
     """Test iteration over datasets with batches using iter_dataset_collection_batches."""
     # Get engine from the fixture
-    engine = get_engine(config_file, "main")
     store = engine.store
     bucket = store.bucket
@@ -81,10 +80,9 @@
     assert filtered_dataset_ids[0] == "dataset-5"


-def test_dataset_state_filter(config_file):
+def test_dataset_state_filter(engine):
     """Test filtering datasets by state."""
     # Get engine from the fixture
-    engine = get_engine(config_file, "main")
     store = engine.store
     bucket = store.bucket
version.""" with patch("ingestify.__version__", "1.0.0"): engine = get_engine(config_file) @@ -13,8 +13,10 @@ def test_store_version_tracking_new_store(config_file): stored_version = engine.store.dataset_repository.get_store_version() assert stored_version == "1.0.0" + db_cleanup(engine) -def test_store_version_tracking_existing_store_same_version(config_file): + +def test_store_version_tracking_existing_store_same_version(config_file, db_cleanup): """Test that an existing store with same version doesn't cause issues.""" with patch("ingestify.__version__", "1.0.0"): # Initialize store first time @@ -29,16 +31,20 @@ def test_store_version_tracking_existing_store_same_version(config_file): stored_version = store2.dataset_repository.get_store_version() assert stored_version == "1.0.0" + db_cleanup(engine1) + -def test_store_version_tracking_version_mismatch(config_file, caplog): +def test_store_version_tracking_version_mismatch(config_file, caplog, db_cleanup): """Test that version mismatch is logged as warning.""" - # Initialize store with version 1.0.0 - with patch("ingestify.__version__", "1.0.0"): + # Use engine as fixture as this cleans up the database + + # Initialize store with version 1.0.1 + with patch("ingestify.__version__", "1.0.1"): engine1 = get_engine(config_file) store1 = engine1.store stored_version = store1.dataset_repository.get_store_version() - assert stored_version == "1.0.0" + assert stored_version == "1.0.1" # Open store with different version with patch("ingestify.__version__", "2.0.0"): @@ -47,16 +53,17 @@ def test_store_version_tracking_version_mismatch(config_file, caplog): # Version should still be the original one stored_version = store2.dataset_repository.get_store_version() - assert stored_version == "1.0.0" + assert stored_version == "1.0.1" # Should have logged a warning about version mismatch assert "Store version mismatch" in caplog.text - assert "stored=1.0.0, current=2.0.0" in caplog.text + assert "stored=1.0.1, current=2.0.0" in caplog.text + + db_cleanup(engine1) -def test_store_version_methods(config_file): +def test_store_version_methods(engine): """Test the repository version methods directly.""" - engine = get_engine(config_file) repo = engine.store.dataset_repository from ingestify import __version__