45 changes: 44 additions & 1 deletion .github/workflows/test.yml
@@ -9,7 +9,7 @@ on:
- main

jobs:
build:
test-sqlite:
runs-on: ${{ matrix.os }}
strategy:
matrix:
@@ -33,3 +33,46 @@ jobs:
- name: Test with pytest
run: |
pytest --color=yes

test-mysql:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.9, "3.10", "3.11", "3.12"]

services:
mysql:
image: mysql:8.0
env:
MYSQL_ROOT_PASSWORD: root
MYSQL_DATABASE: ingestify_test
ports:
- 3306:3306
options: >-
--health-cmd="mysqladmin ping"
--health-interval=10s
--health-timeout=5s
--health-retries=3

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e ".[test]"
- name: Install MySQL dependencies
run: |
pip install mysqlclient
- name: Code formatting
run: |
pip install black==22.3.0
black --check .
- name: Test with pytest
env:
INGESTIFY_TEST_DATABASE_URL: mysql://root:root@127.0.0.1:3306/ingestify_test
run: |
pytest --color=yes
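
To reproduce the test-mysql job locally, start the same mysql:8.0 container with the credentials above and export INGESTIFY_TEST_DATABASE_URL before invoking pytest. A minimal connectivity check, assuming SQLAlchemy plus the mysqlclient driver installed by the job, might look like:

# Hedged sketch: verify the MySQL service is reachable with the same URL
# the workflow passes to pytest (the mysql:// scheme implies mysqlclient).
from sqlalchemy import create_engine, text

DATABASE_URL = "mysql://root:root@127.0.0.1:3306/ingestify_test"

engine = create_engine(DATABASE_URL)
with engine.connect() as conn:
    # One round trip to the server, much like the job's `mysqladmin ping` health check
    assert conn.execute(text("SELECT 1")).scalar() == 1
engine.dispose()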
50 changes: 37 additions & 13 deletions ingestify/infra/store/dataset/sqlalchemy/repository.py
@@ -130,7 +130,7 @@ def __init__(self, url: str, table_prefix: str = ""):
self._init_engine()

# Create all tables in the database
self.metadata.create_all(self.engine)
self.create_all_tables()

def __del__(self):
self.close()
@@ -143,6 +143,14 @@ def close(self):
if hasattr(self, "engine"):
self.engine.dispose()

def create_all_tables(self):
self.metadata.create_all(self.engine)

def drop_all_tables(self):
"""Drop all tables in the database. Useful for test cleanup."""
if hasattr(self, "metadata") and hasattr(self, "engine"):
self.metadata.drop_all(self.engine)

def get(self):
return self.session()

@@ -208,18 +216,33 @@ def _upsert(

primary_key_columns = [column for column in table.columns if column.primary_key]

if immutable_rows:
stmt = stmt.on_conflict_do_nothing(index_elements=primary_key_columns)
if dialect == "mysql":
# MySQL uses ON DUPLICATE KEY UPDATE syntax
if immutable_rows:
# For MySQL immutable rows, use INSERT IGNORE to skip duplicates
stmt = stmt.prefix_with("IGNORE")
else:
# MySQL uses stmt.inserted instead of stmt.excluded
set_ = {
name: stmt.inserted[name]
for name, column in table.columns.items()
if column not in primary_key_columns
}
stmt = stmt.on_duplicate_key_update(set_)
else:
set_ = {
name: getattr(stmt.excluded, name)
for name, column in table.columns.items()
if column not in primary_key_columns
}

stmt = stmt.on_conflict_do_update(
index_elements=primary_key_columns, set_=set_
)
# PostgreSQL and SQLite use ON CONFLICT syntax
if immutable_rows:
stmt = stmt.on_conflict_do_nothing(index_elements=primary_key_columns)
else:
set_ = {
name: getattr(stmt.excluded, name)
for name, column in table.columns.items()
if column not in primary_key_columns
}

stmt = stmt.on_conflict_do_update(
index_elements=primary_key_columns, set_=set_
)

connection.execute(stmt)
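
For reference, here is a self-contained sketch of the two MySQL branches above, using a toy table rather than the repository's real schema:

# Toy illustration of the MySQL upsert paths: INSERT IGNORE for immutable
# rows, ON DUPLICATE KEY UPDATE via stmt.inserted for mutable rows.
from sqlalchemy import Column, Integer, MetaData, String, Table
from sqlalchemy.dialects.mysql import insert

metadata = MetaData()
items = Table(
    "items",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("value", String(64)),
)

stmt = insert(items).values([{"id": 1, "value": "a"}, {"id": 2, "value": "b"}])

# Immutable rows: duplicates are silently skipped.
ignore_stmt = stmt.prefix_with("IGNORE")

# Mutable rows: overwrite every non-key column with the incoming value;
# stmt.inserted is MySQL's counterpart of PostgreSQL's stmt.excluded.
set_ = {
    column.name: stmt.inserted[column.name]
    for column in items.columns
    if not column.primary_key
}
upsert_stmt = stmt.on_duplicate_key_update(set_)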

@@ -242,7 +265,8 @@ def _build_cte_sqlite(self, records, name: str) -> CTE:
def _build_cte(self, records: list[dict], name: str) -> CTE:
"""Build a CTE from a list of dictionaries."""

if self.dialect.name == "sqlite":
if self.dialect.name in ("sqlite", "mysql"):
# SQLite and MySQL don't support the VALUES construct used below, so build the CTE from UNION ALL instead
return self._build_cte_sqlite(records, name)

first_row = records[0]
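
_build_cte_sqlite itself is outside this diff; a hypothetical version of the UNION ALL fallback it implements could look like:

# Hypothetical sketch: one SELECT of labelled literals per record, glued
# together with UNION ALL and wrapped in a named CTE.
from sqlalchemy import literal, select, union_all
from sqlalchemy.sql.selectable import CTE

def build_union_all_cte(records: list[dict], name: str) -> CTE:
    selects = [
        select(*(literal(value).label(key) for key, value in row.items()))
        for row in records
    ]
    return union_all(*selects).cte(name=name)

cte = build_union_all_cte([{"id": 1}, {"id": 2}], "wanted_ids")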
21 changes: 19 additions & 2 deletions ingestify/infra/store/dataset/sqlalchemy/tables.py
@@ -51,14 +51,31 @@ class TZDateTime(TypeDecorator):
LOCAL_TIMEZONE = datetime.datetime.utcnow().astimezone().tzinfo
cache_ok = True

def __init__(self, fsp=None, **kwargs):
super().__init__(**kwargs)
self.fsp = fsp

def load_dialect_impl(self, dialect):
# For MySQL, use DATETIME with fractional seconds precision
if dialect.name == "mysql" and self.fsp is not None:
from sqlalchemy.dialects.mysql import DATETIME as MySQL_DATETIME

# Return the type directly (not via dialect.type_descriptor) so our process_* methods still run
return MySQL_DATETIME(fsp=self.fsp)
return super().load_dialect_impl(dialect)

def process_bind_param(self, value: Optional[datetime.datetime], dialect):
if not value:
return None

if value.tzinfo is None:
value = value.astimezone(self.LOCAL_TIMEZONE)
# Assume naive datetimes are already in UTC
value = value.replace(tzinfo=datetime.timezone.utc)
else:
# Convert timezone-aware datetimes to UTC
value = value.astimezone(datetime.timezone.utc)

return value.astimezone(datetime.timezone.utc)
return value

def process_result_value(self, value, dialect):
if not value:
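
The behavioural change in process_bind_param is easiest to see in isolation: previously naive datetimes were interpreted in the local timezone, now they are assumed to already be UTC. A condensed restatement:

# Condensed restatement of the new bind logic: naive datetimes are taken
# to be UTC, aware ones are converted to UTC.
import datetime

def to_utc(value):
    if not value:
        return None
    if value.tzinfo is None:
        return value.replace(tzinfo=datetime.timezone.utc)
    return value.astimezone(datetime.timezone.utc)

cet = datetime.timezone(datetime.timedelta(hours=2))
assert to_utc(datetime.datetime(2024, 1, 1, 12, 0, tzinfo=cet)).hour == 10
assert to_utc(datetime.datetime(2024, 1, 1, 12, 0)).hour == 12  # no shift for naive input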
7 changes: 5 additions & 2 deletions ingestify/tests/config.yaml
@@ -1,6 +1,9 @@
main:
# Cannot use in memory data because database is shared between processes
metadata_url: !ENV "sqlite:///${TEST_DIR}/main.db"
# For MySQL tests, INGESTIFY_TEST_DATABASE_URL will be set
# For SQLite tests, falls back to using TEST_DIR
metadata_url: !ENV ${INGESTIFY_TEST_DATABASE_URL}
metadata_options:
table_prefix: !ENV ${INGESTIFY_TEST_DATABASE_PREFIX}_
file_url: !ENV file://${TEST_DIR}/data
default_bucket: main

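
The !ENV tag is resolved by ingestify's config loader, which is not part of this diff; a typical PyYAML constructor for this ${VAR} pattern looks roughly like:

# Rough sketch of a ${VAR}-expanding !ENV constructor (the real loader may differ).
import os
import re

import yaml

ENV_PATTERN = re.compile(r"\$\{([^}]+)\}")

def env_constructor(loader, node):
    value = loader.construct_scalar(node)
    # Replace every ${VAR} with its environment value (empty string if unset)
    return ENV_PATTERN.sub(lambda m: os.environ.get(m.group(1), ""), value)

yaml.SafeLoader.add_constructor("!ENV", env_constructor)

os.environ["INGESTIFY_TEST_DATABASE_URL"] = "sqlite:///tmp/main.db"
config = yaml.safe_load("metadata_url: !ENV ${INGESTIFY_TEST_DATABASE_URL}")
assert config["metadata_url"] == "sqlite:///tmp/main.db"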
48 changes: 45 additions & 3 deletions ingestify/tests/conftest.py
@@ -1,8 +1,11 @@
import tempfile

import pytest
import os

import pytest

from ingestify.main import get_engine


@pytest.fixture(scope="function", autouse=True)
def datastore_dir():
@@ -12,6 +15,45 @@ def datastore_dir():
yield tmpdirname


@pytest.fixture(scope="session")
def config_file():
@pytest.fixture(scope="function")
def ingestify_test_database_url(datastore_dir, monkeypatch):
key = "INGESTIFY_TEST_DATABASE_URL"

value = os.environ.get(key)
if value is None:
value = f"sqlite:///{datastore_dir}/main.db"
monkeypatch.setenv(key, value)

yield value


@pytest.fixture(scope="function")
def config_file(ingestify_test_database_url):
# Depend on ingestify_test_database_url so the environment variables are set
# in time and the database is cleaned before ingestify opens a connection
return os.path.abspath(os.path.dirname(__file__) + "/config.yaml")


@pytest.fixture
def db_cleanup():
def do_cleanup(engine):
# Close connections after test
session_provider = getattr(
engine.store.dataset_repository, "session_provider", None
)
if session_provider:
session_provider.session.remove()
session_provider.engine.dispose()
session_provider.drop_all_tables()

return do_cleanup


@pytest.fixture(scope="function")
def engine(config_file, db_cleanup):
# Now create the engine for the test
engine = get_engine(config_file, "main")

yield engine

db_cleanup(engine)
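
With these fixtures in place, a test only has to declare engine: the database URL is resolved, the config is loaded, and db_cleanup drops the tables afterwards. A hypothetical test in this style:

# Hypothetical test using the fixtures above; db_cleanup runs automatically
# when the engine fixture is torn down.
def test_store_starts_empty(engine):
    datasets = list(engine.iter_datasets(auto_ingest=False))
    assert datasets == []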
39 changes: 13 additions & 26 deletions ingestify/tests/test_auto_ingest.py
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@
from ingestify.domain.models.fetch_policy import FetchPolicy
from ingestify.domain import Selector, DataSpecVersionCollection
from ingestify import Source, DatasetResource
from ingestify.utils import utcnow


class MockSource(Source):
@@ -39,7 +40,7 @@ def find_datasets(
url="http://test.com/match1",
).add_file(
data_feed_key="test",
last_modified=datetime.datetime.now(),
last_modified=utcnow(),
json_content={"blaat": "piet"},
)

@@ -75,7 +76,7 @@ def find_datasets(
url="http://test.com/match1",
).add_file(
data_feed_key="test",
last_modified=datetime.datetime.now(),
last_modified=utcnow(),
json_content={"competition_id": 11},
)
elif competition_id == 22:
@@ -91,7 +92,7 @@ def find_datasets(
url="http://test.com/match2",
).add_file(
data_feed_key="test",
last_modified=datetime.datetime.now(),
last_modified=utcnow(),
json_content={"competition_id": 22},
)

@@ -106,10 +107,8 @@ def discover_selectors(self, dataset_type: str):
]


def test_iter_datasets_basic_auto_ingest(config_file):
def test_iter_datasets_basic_auto_ingest(engine):
"""Test basic auto-ingest functionality."""
engine = get_engine(config_file)

# Add a simple ingestion plan
mock_source = MockSource(name="test_source")
data_spec_versions = DataSpecVersionCollection.from_dict({"default": {"v1"}})
@@ -141,20 +140,16 @@ def test_iter_datasets_basic_auto_ingest(config_file):
assert datasets[0].identifier["competition_id"] == 11


def test_iter_datasets_auto_ingest_disabled(config_file):
def test_iter_datasets_auto_ingest_disabled(engine):
"""Test that auto_ingest=False returns only existing datasets."""
engine = get_engine(config_file)

# Should only return existing datasets (none in empty store)
datasets = list(engine.iter_datasets(competition_id=11, auto_ingest=False))

assert len(datasets) == 0


def test_iter_datasets_outside_config_scope(config_file):
def test_iter_datasets_outside_config_scope(engine):
"""Test that requests outside IngestionPlan scope return nothing."""
engine = get_engine(config_file)

# Add plan only for competition_id=11
mock_source = MockSource(name="test_source")
data_spec_versions = DataSpecVersionCollection.from_dict({"default": {"v1"}})
@@ -180,10 +175,8 @@ def test_iter_datasets_outside_config_scope(config_file):
assert len(datasets) == 0


def test_iter_datasets_discover_selectors_with_filters(config_file):
def test_iter_datasets_discover_selectors_with_filters(engine):
"""Test that selector_filters are applied after discover_selectors runs."""
engine = get_engine(config_file)

# Create an IngestionPlan with empty selector - this will trigger discover_selectors
mock_source = MockSourceWithDiscoverSelectors(name="test_source_discover")
data_spec_versions = DataSpecVersionCollection.from_dict({"default": {"v1"}})
@@ -216,10 +209,8 @@ def test_iter_datasets_discover_selectors_with_filters(config_file):
assert datasets[0].name == "Mock match comp 11"


def test_iter_datasets_discover_selectors_multiple_matches(config_file):
def test_iter_datasets_discover_selectors_multiple_matches(engine):
"""Test that multiple discovered selectors can match the filters."""
engine = get_engine(config_file)

# Create an IngestionPlan with empty selector - this will trigger discover_selectors
mock_source = MockSourceWithDiscoverSelectors(name="test_source_discover")
data_spec_versions = DataSpecVersionCollection.from_dict({"default": {"v1"}})
@@ -248,12 +239,10 @@ def test_iter_datasets_discover_selectors_multiple_matches(config_file):
assert competition_ids == {11, 22}


def test_selector_filters_make_discovered_selectors_more_strict(config_file):
def test_selector_filters_make_discovered_selectors_more_strict(engine):
"""Test that when selector_filters are more strict than discovered selectors, we make the selectors more strict."""
from unittest.mock import Mock

engine = get_engine(config_file)

# Create a source that returns multiple matches per season
class MockSourceMultipleMatches(Source):
@property
@@ -291,7 +280,7 @@ def find_datasets(
url=f"http://test.com/match{mid}",
).add_file(
data_feed_key="test",
last_modified=datetime.datetime.now(),
last_modified=utcnow(),
json_content={"match_id": mid},
)
return []
@@ -348,13 +337,11 @@ def discover_selectors(self, dataset_type):
# Without this optimization, we'd call with match_id=None and fetch 3 matches instead of 1


def test_iter_datasets_with_open_data_auto_discovery(config_file):
def test_iter_datasets_with_open_data_auto_discovery(engine):
"""Test that use_open_data=True auto-discovers open data sources without configuration."""
from unittest.mock import Mock
from ingestify.application import loader

engine = get_engine(config_file)

# Create mock source class that inherits from Source
class MockOpenDataSource(Source):
def __init__(self, name):
@@ -387,7 +374,7 @@ def find_datasets(
url="http://open-data.com/match123",
).add_file(
data_feed_key="test",
last_modified=datetime.datetime.now(),
last_modified=utcnow(),
json_content={"match_id": 123},
)

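The switch from datetime.datetime.now() to utcnow() matters because TZDateTime now treats naive datetimes as UTC, so naive local timestamps would be stored with the wrong offset. ingestify.utils.utcnow is not shown in this diff, but presumably amounts to:

# Assumed shape of ingestify.utils.utcnow: a timezone-aware UTC timestamp
import datetime

def utcnow() -> datetime.datetime:
    return datetime.datetime.now(datetime.timezone.utc)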