diff --git a/changelog.d/added/601.md b/changelog.d/added/601.md new file mode 100644 index 00000000..2148d164 --- /dev/null +++ b/changelog.d/added/601.md @@ -0,0 +1 @@ +Add unified version manifest system for semver-based dataset versioning across GCS and Hugging Face, with rollback support and a single registry file (version_manifest.json) on both backends. Exposes `get_data_version()` and `get_data_manifest()` as public consumer APIs. diff --git a/policyengine_us_data/__init__.py b/policyengine_us_data/__init__.py index 17383534..493ff1b8 100644 --- a/policyengine_us_data/__init__.py +++ b/policyengine_us_data/__init__.py @@ -1,2 +1,3 @@ from .datasets import * from .geography import ZIP_CODE_DATASET +from .utils.version_manifest import get_data_version, get_data_manifest diff --git a/policyengine_us_data/tests/conftest.py b/policyengine_us_data/tests/conftest.py new file mode 100644 index 00000000..fb39787c --- /dev/null +++ b/policyengine_us_data/tests/conftest.py @@ -0,0 +1,63 @@ +"""Shared fixtures for version manifest tests.""" + +from unittest.mock import MagicMock + +import pytest + +from policyengine_us_data.utils.version_manifest import ( + HFVersionInfo, + GCSVersionInfo, + VersionManifest, + VersionRegistry, +) + + +@pytest.fixture +def sample_generations() -> dict[str, int]: + return { + "enhanced_cps_2024.h5": 1710203948123456, + "cps_2024.h5": 1710203948234567, + "states/AL.h5": 1710203948345678, + } + + +@pytest.fixture +def sample_hf_info() -> HFVersionInfo: + return HFVersionInfo( + repo="policyengine/policyengine-us-data", + commit="abc123def456", + ) + + +@pytest.fixture +def sample_manifest( + sample_generations: dict[str, int], + sample_hf_info: HFVersionInfo, +) -> VersionManifest: + return VersionManifest( + version="1.72.3", + created_at="2026-03-10T14:30:00Z", + hf=sample_hf_info, + gcs=GCSVersionInfo( + bucket="policyengine-us-data", + generations=sample_generations, + ), + ) + + +@pytest.fixture +def sample_registry( + 
sample_manifest: VersionManifest, +) -> VersionRegistry: + """A registry with one version entry.""" + return VersionRegistry( + current="1.72.3", + versions=[sample_manifest], + ) + + +@pytest.fixture +def mock_bucket() -> MagicMock: + bucket = MagicMock() + bucket.name = "policyengine-us-data" + return bucket diff --git a/policyengine_us_data/tests/fixtures/__init__.py b/policyengine_us_data/tests/fixtures/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/policyengine_us_data/tests/fixtures/test_version_manifest.py b/policyengine_us_data/tests/fixtures/test_version_manifest.py new file mode 100644 index 00000000..2678f031 --- /dev/null +++ b/policyengine_us_data/tests/fixtures/test_version_manifest.py @@ -0,0 +1,25 @@ +"""Helper functions for version manifest tests.""" + +import json +from unittest.mock import MagicMock + +from policyengine_us_data.utils.version_manifest import ( + VersionRegistry, +) + + +def make_mock_blob(generation: int) -> MagicMock: + blob = MagicMock() + blob.generation = generation + return blob + + +def setup_bucket_with_registry( + bucket: MagicMock, + registry: VersionRegistry, +) -> None: + """Configure a mock bucket to serve a registry.""" + registry_json = json.dumps(registry.to_dict()) + blob = MagicMock() + blob.download_as_text.return_value = registry_json + bucket.blob.return_value = blob diff --git a/policyengine_us_data/tests/test_version_manifest.py b/policyengine_us_data/tests/test_version_manifest.py new file mode 100644 index 00000000..b1a150bf --- /dev/null +++ b/policyengine_us_data/tests/test_version_manifest.py @@ -0,0 +1,1027 @@ +"""Tests for version manifest registry system.""" + +import json +from unittest.mock import MagicMock, patch, call + +import pytest +from google.api_core.exceptions import NotFound + +from policyengine_us_data.utils.version_manifest import ( + GCSVersionInfo, + VersionManifest, + VersionRegistry, + build_manifest, + upload_manifest, + get_current_version, + get_manifest, + 
list_versions, + download_versioned_file, + rollback, + get_data_manifest, + get_data_version, +) +from policyengine_us_data.tests.fixtures.test_version_manifest import ( + make_mock_blob, + setup_bucket_with_registry, +) + +_MOD = "policyengine_us_data.utils.version_manifest" + + +# -- VersionManifest serialization tests --------------------------- + + +class TestVersionManifestSerialization: + def test_to_dict(self, sample_manifest): + result = sample_manifest.to_dict() + + assert result["version"] == "1.72.3" + assert result["created_at"] == "2026-03-10T14:30:00Z" + assert result["hf"]["repo"] == ("policyengine/policyengine-us-data") + assert result["hf"]["commit"] == "abc123def456" + assert result["gcs"]["bucket"] == "policyengine-us-data" + assert result["gcs"]["generations"]["enhanced_cps_2024.h5"] == 1710203948123456 + + def test_from_dict(self, sample_manifest): + data = { + "version": "1.72.3", + "created_at": "2026-03-10T14:30:00Z", + "hf": { + "repo": ("policyengine/policyengine-us-data"), + "commit": "abc123def456", + }, + "gcs": { + "bucket": "policyengine-us-data", + "generations": { + "enhanced_cps_2024.h5": (1710203948123456), + "cps_2024.h5": 1710203948234567, + "states/AL.h5": 1710203948345678, + }, + }, + } + result = VersionManifest.from_dict(data) + + assert result.version == "1.72.3" + assert result.hf.commit == "abc123def456" + assert result.hf.repo == ("policyengine/policyengine-us-data") + assert result.gcs.generations["enhanced_cps_2024.h5"] == 1710203948123456 + assert result.gcs.bucket == "policyengine-us-data" + + def test_roundtrip(self, sample_manifest): + roundtripped = VersionManifest.from_dict(sample_manifest.to_dict()) + + assert roundtripped.version == sample_manifest.version + assert roundtripped.created_at == sample_manifest.created_at + assert roundtripped.hf.repo == sample_manifest.hf.repo + assert roundtripped.hf.commit == sample_manifest.hf.commit + assert roundtripped.gcs.bucket == sample_manifest.gcs.bucket + assert 
roundtripped.gcs.generations == sample_manifest.gcs.generations + + def test_without_hf(self, sample_generations): + manifest = VersionManifest( + version="1.72.3", + created_at="2026-03-10T14:30:00Z", + hf=None, + gcs=GCSVersionInfo( + bucket="policyengine-us-data", + generations=sample_generations, + ), + ) + data = manifest.to_dict() + assert data["hf"] is None + + roundtripped = VersionManifest.from_dict(data) + assert roundtripped.hf is None + assert roundtripped.gcs.generations == (sample_generations) + + def test_special_operation_omitted_by_default(self, sample_manifest): + data = sample_manifest.to_dict() + assert "special_operation" not in data + assert "roll_back_version" not in data + + def test_special_operation_included_when_set( + self, sample_generations, sample_hf_info + ): + manifest = VersionManifest( + version="1.73.0", + created_at="2026-03-10T15:00:00Z", + hf=sample_hf_info, + gcs=GCSVersionInfo( + bucket="policyengine-us-data", + generations=sample_generations, + ), + special_operation="roll-back", + roll_back_version="1.70.1", + ) + data = manifest.to_dict() + assert data["special_operation"] == "roll-back" + assert data["roll_back_version"] == "1.70.1" + + def test_special_operation_roundtrip(self, sample_generations, sample_hf_info): + manifest = VersionManifest( + version="1.73.0", + created_at="2026-03-10T15:00:00Z", + hf=sample_hf_info, + gcs=GCSVersionInfo( + bucket="policyengine-us-data", + generations=sample_generations, + ), + special_operation="roll-back", + roll_back_version="1.70.1", + ) + roundtripped = VersionManifest.from_dict(manifest.to_dict()) + assert roundtripped.special_operation == "roll-back" + assert roundtripped.roll_back_version == "1.70.1" + + def test_regular_manifest_has_no_special_operation( + self, + ): + data = { + "version": "1.72.3", + "created_at": "2026-03-10T14:30:00Z", + "hf": None, + "gcs": { + "bucket": "b", + "generations": {"f.h5": 123}, + }, + } + result = VersionManifest.from_dict(data) + assert 
result.special_operation is None + assert result.roll_back_version is None + + +# -- VersionRegistry serialization tests --------------------------- + + +class TestVersionRegistrySerialization: + def test_to_dict(self, sample_registry): + result = sample_registry.to_dict() + + assert result["current"] == "1.72.3" + assert len(result["versions"]) == 1 + assert result["versions"][0]["version"] == "1.72.3" + + def test_from_dict(self, sample_manifest): + data = { + "current": "1.72.3", + "versions": [sample_manifest.to_dict()], + } + result = VersionRegistry.from_dict(data) + + assert result.current == "1.72.3" + assert len(result.versions) == 1 + assert result.versions[0].version == "1.72.3" + assert result.versions[0].hf.commit == "abc123def456" + + def test_roundtrip(self, sample_registry): + roundtripped = VersionRegistry.from_dict(sample_registry.to_dict()) + assert roundtripped.current == sample_registry.current + assert len(roundtripped.versions) == len(sample_registry.versions) + assert roundtripped.versions[0].version == "1.72.3" + + def test_get_version(self, sample_registry): + result = sample_registry.get_version("1.72.3") + assert result.version == "1.72.3" + assert result.hf.commit == "abc123def456" + + def test_get_version_not_found(self, sample_registry): + with pytest.raises(ValueError, match="not found"): + sample_registry.get_version("9.9.9") + + def test_empty_registry(self): + registry = VersionRegistry() + assert registry.current == "" + assert registry.versions == [] + + data = registry.to_dict() + assert data == {"current": "", "versions": []} + + +# -- build_manifest tests ------------------------------------------ + + +class TestBuildManifest: + @patch(f"{_MOD}._get_gcs_bucket") + def test_structure(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket + blob_names = [ + "file_a.h5", + "file_b.h5", + "file_c.h5", + ] + mock_bucket.get_blob.side_effect = [ + make_mock_blob(100), + make_mock_blob(200), + 
make_mock_blob(300), + ] + + result = build_manifest("1.72.3", blob_names) + + assert isinstance(result, VersionManifest) + assert result.version == "1.72.3" + assert result.created_at.endswith("Z") + assert result.gcs.generations == { + "file_a.h5": 100, + "file_b.h5": 200, + "file_c.h5": 300, + } + assert result.gcs.bucket == "policyengine-us-data" + assert result.hf is None + + @patch(f"{_MOD}._get_gcs_bucket") + def test_with_subdirectories(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket + blob_names = [ + "states/AL.h5", + "districts/CA-01.h5", + ] + mock_bucket.get_blob.side_effect = [ + make_mock_blob(111), + make_mock_blob(222), + ] + + result = build_manifest("1.72.3", blob_names) + + assert "states/AL.h5" in result.gcs.generations + assert "districts/CA-01.h5" in result.gcs.generations + assert result.gcs.generations["states/AL.h5"] == 111 + assert result.gcs.generations["districts/CA-01.h5"] == 222 + + @patch(f"{_MOD}._get_gcs_bucket") + def test_with_hf_info( + self, + mock_get_bucket, + mock_bucket, + sample_hf_info, + ): + mock_get_bucket.return_value = mock_bucket + mock_bucket.get_blob.return_value = make_mock_blob(999) + + result = build_manifest( + "1.72.3", + ["file.h5"], + hf_info=sample_hf_info, + ) + + assert result.hf is not None + assert result.hf.commit == "abc123def456" + assert result.hf.repo == ("policyengine/policyengine-us-data") + + @patch(f"{_MOD}._get_gcs_bucket") + def test_missing_blob_raises(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket + mock_bucket.get_blob.return_value = None + + with pytest.raises(ValueError, match="not found"): + build_manifest("1.72.3", ["missing.h5"]) + + +# -- upload_manifest tests ----------------------------------------- + + +class TestUploadManifest: + def _setup_empty_registry(self, bucket): + """Mock bucket with no existing registry.""" + written = {} + + def mock_blob(name): + if name == "version_manifest.json": + b = MagicMock() 
+ b.name = name + b.download_as_text.side_effect = NotFound("Not found") + written[name] = b + return b + b = MagicMock() + b.name = name + written[name] = b + return b + + bucket.blob.side_effect = mock_blob + return written + + @patch(f"{_MOD}._upload_registry_to_hf") + @patch(f"{_MOD}._get_gcs_bucket") + def test_writes_registry_to_gcs( + self, + mock_get_bucket, + mock_hf, + mock_bucket, + sample_manifest, + ): + mock_get_bucket.return_value = mock_bucket + written = self._setup_empty_registry(mock_bucket) + + upload_manifest(sample_manifest) + + assert "version_manifest.json" in written + blob = written["version_manifest.json"] + written_json = blob.upload_from_string.call_args[0][0] + registry_data = json.loads(written_json) + + assert registry_data["current"] == "1.72.3" + assert len(registry_data["versions"]) == 1 + assert registry_data["versions"][0]["version"] == "1.72.3" + + @patch(f"{_MOD}._upload_registry_to_hf") + @patch(f"{_MOD}._get_gcs_bucket") + def test_includes_hf_commit( + self, + mock_get_bucket, + mock_hf, + mock_bucket, + sample_manifest, + ): + mock_get_bucket.return_value = mock_bucket + written = self._setup_empty_registry(mock_bucket) + + upload_manifest(sample_manifest) + + blob = written["version_manifest.json"] + written_json = blob.upload_from_string.call_args[0][0] + registry_data = json.loads(written_json) + + assert registry_data["versions"][0]["hf"]["commit"] == "abc123def456" + + @patch(f"{_MOD}._upload_registry_to_hf") + @patch(f"{_MOD}._get_gcs_bucket") + def test_appends_to_existing_registry( + self, + mock_get_bucket, + mock_hf, + mock_bucket, + sample_manifest, + ): + mock_get_bucket.return_value = mock_bucket + older = VersionManifest( + version="1.72.2", + created_at="2026-03-09T10:00:00Z", + hf=None, + gcs=GCSVersionInfo( + bucket="policyengine-us-data", + generations={"old.h5": 111}, + ), + ) + existing_registry = VersionRegistry(current="1.72.2", versions=[older]) + existing_json = 
json.dumps(existing_registry.to_dict()) + written = {} + + def mock_blob(name): + b = MagicMock() + b.name = name + b.download_as_text.return_value = existing_json + written[name] = b + return b + + mock_bucket.blob.side_effect = mock_blob + + upload_manifest(sample_manifest) + + blob = written["version_manifest.json"] + written_json = blob.upload_from_string.call_args[0][0] + registry_data = json.loads(written_json) + + assert registry_data["current"] == "1.72.3" + assert len(registry_data["versions"]) == 2 + assert registry_data["versions"][0]["version"] == "1.72.3" + assert registry_data["versions"][1]["version"] == "1.72.2" + + @patch(f"{_MOD}.os") + @patch(f"{_MOD}.HfApi") + @patch(f"{_MOD}._get_gcs_bucket") + def test_always_uploads_to_hf( + self, + mock_get_bucket, + mock_hf_api_cls, + mock_os, + mock_bucket, + sample_manifest, + ): + mock_get_bucket.return_value = mock_bucket + mock_os.environ.get.return_value = "fake_token" + mock_os.unlink = MagicMock() + mock_api = MagicMock() + mock_hf_api_cls.return_value = mock_api + + blob = MagicMock() + blob.download_as_text.side_effect = NotFound("Not found") + mock_bucket.blob.return_value = blob + + upload_manifest(sample_manifest) + + mock_api.upload_file.assert_called_once() + call_kwargs = mock_api.upload_file.call_args.kwargs + assert call_kwargs["path_in_repo"] == ("version_manifest.json") + assert call_kwargs["repo_id"] == ("policyengine/policyengine-us-data") + + +# -- get_current_version tests ------------------------------------- + + +class TestGetCurrentVersion: + @patch(f"{_MOD}._get_gcs_bucket") + def test_returns_version( + self, + mock_get_bucket, + mock_bucket, + sample_registry, + ): + mock_get_bucket.return_value = mock_bucket + setup_bucket_with_registry(mock_bucket, sample_registry) + + result = get_current_version() + + assert result == "1.72.3" + mock_bucket.blob.assert_called_with("version_manifest.json") + + @patch(f"{_MOD}._get_gcs_bucket") + def test_no_registry_returns_none(self, 
mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket + blob = MagicMock() + blob.download_as_text.side_effect = NotFound("Not found") + mock_bucket.blob.return_value = blob + + result = get_current_version() + + assert result is None + + +# -- get_manifest tests --------------------------------------------- + + +class TestGetManifest: + @patch(f"{_MOD}._get_gcs_bucket") + def test_specific_version( + self, + mock_get_bucket, + mock_bucket, + sample_registry, + ): + mock_get_bucket.return_value = mock_bucket + setup_bucket_with_registry(mock_bucket, sample_registry) + + result = get_manifest("1.72.3") + + assert isinstance(result, VersionManifest) + assert result.version == "1.72.3" + assert result.hf.commit == "abc123def456" + assert result.gcs.generations["enhanced_cps_2024.h5"] == 1710203948123456 + + @patch(f"{_MOD}._get_gcs_bucket") + def test_nonexistent_version( + self, + mock_get_bucket, + mock_bucket, + sample_registry, + ): + mock_get_bucket.return_value = mock_bucket + setup_bucket_with_registry(mock_bucket, sample_registry) + + with pytest.raises(ValueError, match="not found"): + get_manifest("9.9.9") + + @patch(f"{_MOD}._get_gcs_bucket") + def test_no_registry_raises(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket + blob = MagicMock() + blob.download_as_text.side_effect = NotFound("Not found") + mock_bucket.blob.return_value = blob + + with pytest.raises(ValueError, match="not found"): + get_manifest("1.72.3") + + +# -- list_versions tests ------------------------------------------- + + +class TestListVersions: + @patch(f"{_MOD}._get_gcs_bucket") + def test_returns_sorted(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket + v1 = VersionManifest( + version="1.72.1", + created_at="t1", + hf=None, + gcs=GCSVersionInfo(bucket="b", generations={"f.h5": 1}), + ) + v2 = VersionManifest( + version="1.72.3", + created_at="t2", + hf=None, + gcs=GCSVersionInfo(bucket="b", 
generations={"f.h5": 2}), + ) + v3 = VersionManifest( + version="1.72.2", + created_at="t3", + hf=None, + gcs=GCSVersionInfo(bucket="b", generations={"f.h5": 3}), + ) + registry = VersionRegistry(current="1.72.3", versions=[v2, v3, v1]) + setup_bucket_with_registry(mock_bucket, registry) + + result = list_versions() + + assert result == [ + "1.72.1", + "1.72.2", + "1.72.3", + ] + + @patch(f"{_MOD}._get_gcs_bucket") + def test_empty(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket + registry = VersionRegistry() + setup_bucket_with_registry(mock_bucket, registry) + + result = list_versions() + + assert result == [] + + @patch(f"{_MOD}._get_gcs_bucket") + def test_semver_ordering(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket + versions = [ + "1.100.0", + "2.0.0", + "1.9.0", + "1.10.0", + ] + manifests = [ + VersionManifest( + version=v, + created_at="t", + hf=None, + gcs=GCSVersionInfo( + bucket="b", + generations={"f.h5": i}, + ), + ) + for i, v in enumerate(versions) + ] + registry = VersionRegistry( + current="2.0.0", + versions=manifests, + ) + setup_bucket_with_registry(mock_bucket, registry) + + result = list_versions() + + assert result == [ + "1.9.0", + "1.10.0", + "1.100.0", + "2.0.0", + ] + + +# -- download_versioned_file tests --------------------------------- + + +class TestDownloadVersionedFile: + @patch(f"{_MOD}._get_gcs_bucket") + def test_downloads_correct_generation( + self, + mock_get_bucket, + mock_bucket, + sample_manifest, + tmp_path, + ): + mock_get_bucket.return_value = mock_bucket + registry = VersionRegistry( + current="1.72.3", + versions=[sample_manifest], + ) + registry_json = json.dumps(registry.to_dict()) + + def mock_blob(name, generation=None): + if name == "version_manifest.json": + blob = MagicMock() + blob.download_as_text.return_value = registry_json + return blob + blob = MagicMock() + blob.name = name + blob.generation = generation + return blob + + 
mock_bucket.blob.side_effect = mock_blob + + local_path = str(tmp_path / "AL.h5") + download_versioned_file( + "states/AL.h5", + "1.72.3", + local_path, + ) + + calls = mock_bucket.blob.call_args_list + gen_call = [ + c + for c in calls + if c + == call( + "states/AL.h5", + generation=1710203948345678, + ) + ] + assert len(gen_call) == 1 + + @patch(f"{_MOD}._get_gcs_bucket") + def test_file_not_in_manifest( + self, + mock_get_bucket, + mock_bucket, + sample_manifest, + tmp_path, + ): + mock_get_bucket.return_value = mock_bucket + registry = VersionRegistry( + current="1.72.3", + versions=[sample_manifest], + ) + setup_bucket_with_registry(mock_bucket, registry) + + with pytest.raises(ValueError, match="not found"): + download_versioned_file( + "nonexistent.h5", + "1.72.3", + str(tmp_path / "out.h5"), + ) + + +# -- rollback tests ------------------------------------------------- + + +class TestRollback: + @patch(f"{_MOD}.CommitOperationAdd") + @patch(f"{_MOD}.hf_hub_download") + @patch(f"{_MOD}.HfApi") + @patch(f"{_MOD}.os") + @patch(f"{_MOD}._get_gcs_bucket") + def test_creates_new_version_with_old_data( + self, + mock_get_bucket, + mock_os, + mock_hf_api_cls, + mock_hf_download, + mock_commit_op, + mock_bucket, + sample_manifest, + ): + mock_get_bucket.return_value = mock_bucket + mock_os.environ.get.return_value = "fake_token" + mock_os.path.join = lambda *args: "/".join(args) + mock_os.unlink = MagicMock() + + mock_api = MagicMock() + mock_hf_api_cls.return_value = mock_api + commit_info = MagicMock() + commit_info.oid = "new_commit_sha" + mock_api.create_commit.return_value = commit_info + + registry = VersionRegistry( + current="1.72.3", + versions=[sample_manifest], + ) + registry_json = json.dumps(registry.to_dict()) + written = {} + + def mock_blob(name, generation=None): + if name == "version_manifest.json": + b = MagicMock() + b.name = name + b.download_as_text.return_value = registry_json + written[name] = b + return b + blob = MagicMock() + blob.name = 
name + blob.generation = generation + return blob + + mock_bucket.blob.side_effect = mock_blob + + new_gen_counter = iter([50001, 50002, 50003]) + + def mock_get_blob(name): + blob = MagicMock() + blob.generation = next(new_gen_counter) + return blob + + mock_bucket.get_blob.side_effect = mock_get_blob + + result = rollback( + target_version="1.72.3", + new_version="1.73.0", + ) + + assert isinstance(result, VersionManifest) + assert result.version == "1.73.0" + assert result.special_operation == "roll-back" + assert result.roll_back_version == "1.72.3" + + assert mock_bucket.copy_blob.call_count == 3 + + blob = written["version_manifest.json"] + written_json = blob.upload_from_string.call_args[0][0] + registry_data = json.loads(written_json) + + assert registry_data["current"] == "1.73.0" + assert len(registry_data["versions"]) == 2 + assert registry_data["versions"][0]["version"] == "1.73.0" + assert registry_data["versions"][0]["special_operation"] == "roll-back" + + mock_api.create_commit.assert_called_once() + commit_msg = mock_api.create_commit.call_args.kwargs["commit_message"] + assert "1.72.3" in commit_msg + assert "1.73.0" in commit_msg + mock_api.create_tag.assert_called_once() + + @patch(f"{_MOD}._upload_registry_to_hf") + @patch(f"{_MOD}._get_gcs_bucket") + def test_rollback_without_hf( + self, + mock_get_bucket, + mock_hf_upload, + mock_bucket, + ): + mock_get_bucket.return_value = mock_bucket + + old_manifest = VersionManifest( + version="1.72.3", + created_at="2026-03-10T14:30:00Z", + hf=None, + gcs=GCSVersionInfo( + bucket="policyengine-us-data", + generations={"file.h5": 111}, + ), + ) + registry = VersionRegistry( + current="1.72.3", + versions=[old_manifest], + ) + registry_json = json.dumps(registry.to_dict()) + written = {} + + def mock_blob(name, generation=None): + if name == "version_manifest.json": + b = MagicMock() + b.name = name + b.download_as_text.return_value = registry_json + written[name] = b + return b + blob = MagicMock() + 
blob.name = name + blob.generation = generation + return blob + + mock_bucket.blob.side_effect = mock_blob + + restored_blob = MagicMock() + restored_blob.generation = 222 + mock_bucket.get_blob.return_value = restored_blob + + result = rollback( + target_version="1.72.3", + new_version="1.73.0", + ) + + assert result.version == "1.73.0" + assert result.hf is None + assert result.special_operation == "roll-back" + assert mock_bucket.copy_blob.call_count == 1 + + blob = written["version_manifest.json"] + written_json = blob.upload_from_string.call_args[0][0] + registry_data = json.loads(written_json) + assert registry_data["versions"][0]["hf"] is None + + @patch(f"{_MOD}._get_gcs_bucket") + def test_nonexistent_version(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket + blob = MagicMock() + blob.download_as_text.side_effect = NotFound("Not found") + mock_bucket.blob.return_value = blob + + with pytest.raises(ValueError, match="not found"): + rollback( + target_version="9.9.9", + new_version="9.10.0", + ) + + +# -- upload_files_to_gcs return value test -------------------------- + + +class TestUploadFilesToGcsReturnsGenerations: + @patch("policyengine_us_data.utils.data_upload.google.auth") + @patch("policyengine_us_data.utils.data_upload.storage") + def test_returns_generations(self, mock_storage, mock_auth, tmp_path): + from policyengine_us_data.utils.data_upload import ( + upload_files_to_gcs, + ) + + mock_auth.default.return_value = ( + MagicMock(), + "project", + ) + mock_client = MagicMock() + mock_storage.Client.return_value = mock_client + mock_bucket = MagicMock() + mock_client.bucket.return_value = mock_bucket + + file_a = tmp_path / "file_a.h5" + file_a.write_bytes(b"data_a") + file_b = tmp_path / "file_b.h5" + file_b.write_bytes(b"data_b") + + blob_a = MagicMock() + blob_a.generation = 99999 + blob_b = MagicMock() + blob_b.generation = 88888 + + mock_bucket.blob.side_effect = [blob_a, blob_b] + + result = 
upload_files_to_gcs( + files=[str(file_a), str(file_b)], + version="1.72.3", + ) + + assert result == { + "file_a.h5": 99999, + "file_b.h5": 88888, + } + + +# -- End-to-end upload creates registry test ------------------------ + + +class TestEndToEndUploadCreatesRegistry: + @patch(f"{_MOD}._upload_registry_to_hf") + @patch(f"{_MOD}._get_gcs_bucket") + @patch("policyengine_us_data.utils.data_upload.google.auth") + @patch("policyengine_us_data.utils.data_upload.storage") + @patch("policyengine_us_data.utils.data_upload.HfApi") + @patch("policyengine_us_data.utils.data_upload.os") + def test_creates_registry( + self, + mock_os, + mock_hf_api_cls, + mock_du_storage, + mock_du_auth, + mock_get_bucket, + mock_hf_upload, + tmp_path, + ): + from policyengine_us_data.utils.data_upload import ( + upload_data_files, + ) + + mock_du_auth.default.return_value = ( + MagicMock(), + "project", + ) + mock_os.environ.get.return_value = "fake_token" + + mock_api = MagicMock() + mock_hf_api_cls.return_value = mock_api + commit_info = MagicMock() + commit_info.oid = "abc123" + mock_api.create_commit.return_value = commit_info + + mock_client = MagicMock() + mock_du_storage.Client.return_value = mock_client + mock_bucket = MagicMock() + mock_bucket.name = "policyengine-us-data" + mock_client.bucket.return_value = mock_bucket + mock_get_bucket.return_value = mock_bucket + + blob_data = MagicMock() + blob_data.generation = 55555 + + written = {} + + def mock_blob(name): + if name == "version_manifest.json": + b = MagicMock() + b.name = name + b.download_as_text.side_effect = NotFound("Not found") + written[name] = b + return b + return blob_data + + mock_bucket.blob.side_effect = mock_blob + + test_file = tmp_path / "test.h5" + test_file.write_bytes(b"test_data") + + upload_data_files( + files=[str(test_file)], + version="1.72.3", + ) + + assert "version_manifest.json" in written + blob = written["version_manifest.json"] + written_json = blob.upload_from_string.call_args[0][0] + 
registry_data = json.loads(written_json) + + assert registry_data["current"] == "1.72.3" + assert len(registry_data["versions"]) == 1 + assert registry_data["versions"][0]["hf"]["commit"] == "abc123" + assert "test.h5" in registry_data["versions"][0]["gcs"]["generations"] + + +# -- Consumer API tests -------------------------------------------- + + +class TestGetDataManifest: + def setup_method(self): + import policyengine_us_data.utils.version_manifest as mod + + mod._cached_registry = None + + def teardown_method(self): + import policyengine_us_data.utils.version_manifest as mod + + mod._cached_registry = None + + @patch(f"{_MOD}.hf_hub_download") + def test_returns_registry(self, mock_download, tmp_path): + registry_data = { + "current": "1.72.3", + "versions": [ + { + "version": "1.72.3", + "created_at": ("2026-03-10T14:30:00Z"), + "hf": { + "repo": ("policyengine/policyengine-us-data"), + "commit": "abc123", + }, + "gcs": { + "bucket": ("policyengine-us-data"), + "generations": {"file.h5": 12345}, + }, + }, + ], + } + registry_file = tmp_path / "version_manifest.json" + registry_file.write_text(json.dumps(registry_data)) + mock_download.return_value = str(registry_file) + + result = get_data_manifest() + + assert isinstance(result, VersionRegistry) + assert result.current == "1.72.3" + assert len(result.versions) == 1 + assert result.versions[0].hf.commit == "abc123" + mock_download.assert_called_once_with( + repo_id=("policyengine/policyengine-us-data"), + repo_type="model", + filename="version_manifest.json", + ) + + @patch(f"{_MOD}.hf_hub_download") + def test_caches_result(self, mock_download, tmp_path): + registry_data = { + "current": "1.72.3", + "versions": [ + { + "version": "1.72.3", + "created_at": ("2026-03-10T14:30:00Z"), + "hf": None, + "gcs": { + "bucket": "b", + "generations": {"f.h5": 1}, + }, + }, + ], + } + registry_file = tmp_path / "version_manifest.json" + registry_file.write_text(json.dumps(registry_data)) + mock_download.return_value = 
str(registry_file) + + first = get_data_manifest() + second = get_data_manifest() + + assert first is second + assert mock_download.call_count == 1 + + @patch(f"{_MOD}.hf_hub_download") + def test_get_data_version(self, mock_download, tmp_path): + registry_data = { + "current": "1.72.3", + "versions": [ + { + "version": "1.72.3", + "created_at": ("2026-03-10T14:30:00Z"), + "hf": None, + "gcs": { + "bucket": "b", + "generations": {"f.h5": 1}, + }, + }, + ], + } + registry_file = tmp_path / "version_manifest.json" + registry_file.write_text(json.dumps(registry_data)) + mock_download.return_value = str(registry_file) + + result = get_data_version() + + assert result == "1.72.3" diff --git a/policyengine_us_data/utils/data_upload.py b/policyengine_us_data/utils/data_upload.py index c8a50036..2dfb24ed 100644 --- a/policyengine_us_data/utils/data_upload.py +++ b/policyengine_us_data/utils/data_upload.py @@ -6,16 +6,19 @@ CommitOperationDelete, hf_hub_download, ) -from huggingface_hub.errors import RevisionNotFoundError from google.cloud import storage from pathlib import Path from importlib import metadata import google.auth import httpx -import json import logging import os +from policyengine_us_data.utils.version_manifest import ( + GCS_BUCKET_NAME, + HF_REPO_NAME, + HF_REPO_TYPE, +) from tenacity import ( retry, stop_after_attempt, @@ -31,36 +34,59 @@ def upload_data_files( files: List[str], - gcs_bucket_name: str = "policyengine-us-data", - hf_repo_name: str = "policyengine/policyengine-us-data", - hf_repo_type: str = "model", - version: str = None, -): + gcs_bucket_name: str = GCS_BUCKET_NAME, + hf_repo_name: str = HF_REPO_NAME, + hf_repo_type: str = HF_REPO_TYPE, + version: Optional[str] = None, +) -> None: + from policyengine_us_data.utils.version_manifest import ( + GCSVersionInfo, + HFVersionInfo, + VersionManifest, + upload_manifest, + _utc_now_iso, + ) + if version is None: version = metadata.version("policyengine-us-data") - upload_files_to_hf( + hf_commit = 
upload_files_to_hf( files=files, version=version, hf_repo_name=hf_repo_name, hf_repo_type=hf_repo_type, ) - upload_files_to_gcs( + generations = upload_files_to_gcs( files=files, version=version, gcs_bucket_name=gcs_bucket_name, ) + manifest = VersionManifest( + version=version, + created_at=_utc_now_iso(), + hf=HFVersionInfo(repo=hf_repo_name, commit=hf_commit), + gcs=GCSVersionInfo( + bucket=gcs_bucket_name, + generations=generations, + ), + ) + upload_manifest(manifest) + logging.info(f"Created version manifest for {version}.") + def upload_files_to_hf( files: List[str], version: str, - hf_repo_name: str = "policyengine/policyengine-us-data", - hf_repo_type: str = "model", -): - """ - Upload files to Hugging Face repository and tag the commit with the version. + hf_repo_name: str = HF_REPO_NAME, + hf_repo_type: str = HF_REPO_TYPE, +) -> str: + """Upload files to Hugging Face repository and tag the + commit with the version. + + Returns: + The commit SHA (oid) of the created commit. """ api = HfApi() hf_operations = [] @@ -87,39 +113,37 @@ def upload_files_to_hf( ) logging.info(f"Uploaded files to Hugging Face repository {hf_repo_name}.") - # Tag commit with version - try: - api.create_tag( - token=token, - repo_id=hf_repo_name, - tag=version, - revision=commit_info.oid, - repo_type=hf_repo_type, - ) - logging.info( - f"Tagged commit with {version} in Hugging Face repository {hf_repo_name}." - ) - except Exception as e: - if "Tag reference exists already" in str(e) or "409" in str(e): - logging.warning( - f"Tag {version} already exists in {hf_repo_name}. Skipping tag creation." - ) - else: - raise + api.create_tag( + token=token, + repo_id=hf_repo_name, + tag=version, + revision=commit_info.oid, + repo_type=hf_repo_type, + exist_ok=True, + ) + logging.info( + f"Tagged commit with {version} in Hugging Face repository {hf_repo_name}." 
+ ) + + return commit_info.oid def upload_files_to_gcs( files: List[str], version: str, - gcs_bucket_name: str = "policyengine-us-data", -): - """ - Upload files to Google Cloud Storage and set metadata with the version. + gcs_bucket_name: str = GCS_BUCKET_NAME, +) -> Dict[str, int]: + """Upload files to Google Cloud Storage and set metadata + with the version. + + Returns: + Dict mapping blob name to its GCS generation number. """ credentials, project_id = google.auth.default() storage_client = storage.Client(credentials=credentials, project=project_id) bucket = storage_client.bucket(gcs_bucket_name) + generations: Dict[str, int] = {} for file_path in files: file_path = Path(file_path) blob = bucket.blob(file_path.name) @@ -129,28 +153,39 @@ def upload_files_to_gcs( # Set metadata blob.metadata = {"version": version} blob.patch() + + # Read back generation number for manifest + blob.reload() + generations[file_path.name] = blob.generation logging.info( - f"Set metadata for {file_path.name} in GCS bucket {gcs_bucket_name}." + f"Set metadata for {file_path.name} in GCS " + f"bucket {gcs_bucket_name} " + f"(generation {blob.generation})." ) + return generations + def upload_local_area_file( file_path: str, subdirectory: str, - gcs_bucket_name: str = "policyengine-us-data", - hf_repo_name: str = "policyengine/policyengine-us-data", - hf_repo_type: str = "model", + gcs_bucket_name: str = GCS_BUCKET_NAME, + hf_repo_name: str = HF_REPO_NAME, + hf_repo_type: str = HF_REPO_TYPE, version: str = None, skip_hf: bool = False, -): - """ - Upload a single local area H5 file to a subdirectory. +) -> int: + """Upload a single local area H5 file to a subdirectory. Supports states/, districts/, cities/, and national/. Uploads to both GCS and Hugging Face. Args: - skip_hf: If True, skip HuggingFace upload (for batched uploads later) + skip_hf: If True, skip HuggingFace upload (for batched + uploads later) + + Returns: + The GCS generation number of the uploaded blob. 
""" if version is None: version = metadata.version("policyengine-us-data") @@ -169,10 +204,15 @@ def upload_local_area_file( blob.upload_from_filename(file_path) blob.metadata = {"version": version} blob.patch() - logging.info(f"Uploaded {blob_name} to GCS bucket {gcs_bucket_name}.") + blob.reload() + generation = blob.generation + logging.info( + f"Uploaded {blob_name} to GCS bucket " + f"{gcs_bucket_name} (generation {generation})." + ) if skip_hf: - return + return generation # Upload to Hugging Face with subdirectory token = os.environ.get("HUGGING_FACE_TOKEN") @@ -183,17 +223,21 @@ def upload_local_area_file( repo_id=hf_repo_name, repo_type=hf_repo_type, token=token, - commit_message=f"Upload {subdirectory}/{file_path.name} for version {version}", + commit_message=( + f"Upload {subdirectory}/{file_path.name} for version {version}" + ), ) logging.info( f"Uploaded {subdirectory}/{file_path.name} to Hugging Face {hf_repo_name}." ) + return generation + def upload_local_area_batch_to_hf( files_with_subdirs: List[tuple], - hf_repo_name: str = "policyengine/policyengine-us-data", - hf_repo_type: str = "model", + hf_repo_name: str = HF_REPO_NAME, + hf_repo_type: str = HF_REPO_TYPE, version: str = None, ): """ @@ -278,8 +322,8 @@ def hf_create_commit_with_retry( def upload_to_staging_hf( files_with_paths: List[Tuple[Path, str]], version: str, - hf_repo_name: str = "policyengine/policyengine-us-data", - hf_repo_type: str = "model", + hf_repo_name: str = HF_REPO_NAME, + hf_repo_type: str = HF_REPO_TYPE, batch_size: int = 50, ) -> int: """ @@ -338,8 +382,8 @@ def upload_to_staging_hf( def promote_staging_to_production_hf( files: List[str], version: str, - hf_repo_name: str = "policyengine/policyengine-us-data", - hf_repo_type: str = "model", + hf_repo_name: str = HF_REPO_NAME, + hf_repo_type: str = HF_REPO_TYPE, ) -> int: """ Atomically promote files from staging/ to production paths. 
@@ -406,8 +450,8 @@ def promote_staging_to_production_hf( def cleanup_staging_hf( files: List[str], version: str, - hf_repo_name: str = "policyengine/policyengine-us-data", - hf_repo_type: str = "model", + hf_repo_name: str = HF_REPO_NAME, + hf_repo_type: str = HF_REPO_TYPE, ) -> int: """ Clean up staging folder after successful promotion. diff --git a/policyengine_us_data/utils/version_manifest.py b/policyengine_us_data/utils/version_manifest.py new file mode 100644 index 00000000..b51df586 --- /dev/null +++ b/policyengine_us_data/utils/version_manifest.py @@ -0,0 +1,559 @@ +""" +Version registry for semver-based dataset versioning. + +Provides typed structures and functions for versioned uploads, +downloads, and rollbacks across GCS and Hugging Face. All +versions are tracked in a single registry file +(version_manifest.json) on both backends. +""" + +import json +import logging +import os +import tempfile +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Optional + +import google.auth +from packaging.version import Version +from google.api_core.exceptions import NotFound +from google.cloud import storage +from huggingface_hub import ( + HfApi, + CommitOperationAdd, + hf_hub_download, +) + +# -- Configuration ------------------------------------------------- + +REGISTRY_BLOB = "version_manifest.json" +GCS_BUCKET_NAME = "policyengine-us-data" +HF_REPO_NAME = "policyengine/policyengine-us-data" +HF_REPO_TYPE = "model" + + +# -- Types --------------------------------------------------------- + + +@dataclass +class HFVersionInfo: + """Hugging Face backend location for a version.""" + + repo: str + commit: str + + def to_dict(self) -> dict[str, str]: + return {"repo": self.repo, "commit": self.commit} + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "HFVersionInfo": + return cls(repo=data["repo"], commit=data["commit"]) + + +@dataclass +class GCSVersionInfo: + """GCS 
backend location for a version.""" + + bucket: str + generations: dict[str, int] + + def to_dict(self) -> dict[str, Any]: + return { + "bucket": self.bucket, + "generations": self.generations, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "GCSVersionInfo": + return cls( + bucket=data["bucket"], + generations=data["generations"], + ) + + +@dataclass +class VersionManifest: + """Single version entry tying semver to backend + identifiers. + + Consumers interact only with the semver version string. + HF commit SHAs and GCS generation numbers are internal + implementation details resolved by this manifest. + """ + + version: str + created_at: str + hf: Optional[HFVersionInfo] + gcs: GCSVersionInfo + special_operation: Optional[str] = None + roll_back_version: Optional[str] = None + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = { + "version": self.version, + "created_at": self.created_at, + "hf": self.hf.to_dict() if self.hf else None, + "gcs": self.gcs.to_dict(), + } + if self.special_operation is not None: + result["special_operation"] = self.special_operation + if self.roll_back_version is not None: + result["roll_back_version"] = self.roll_back_version + return result + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "VersionManifest": + hf_data = data.get("hf") + return cls( + version=data["version"], + created_at=data["created_at"], + hf=(HFVersionInfo.from_dict(hf_data) if hf_data else None), + gcs=GCSVersionInfo.from_dict(data["gcs"]), + special_operation=data.get("special_operation"), + roll_back_version=data.get("roll_back_version"), + ) + + +@dataclass +class VersionRegistry: + """Registry of all dataset versions. + + Contains a pointer to the current version and a list of + all version manifests (most recent first). 
+ """ + + current: str = "" + versions: list[VersionManifest] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: + return { + "current": self.current, + "versions": [v.to_dict() for v in self.versions], + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "VersionRegistry": + return cls( + current=data["current"], + versions=[VersionManifest.from_dict(v) for v in data["versions"]], + ) + + def get_version(self, version: str) -> VersionManifest: + """Look up a specific version entry. + + Args: + version: Semver version string. + + Returns: + The matching VersionManifest. + + Raises: + ValueError: If the version is not in the + registry. + """ + for v in self.versions: + if v.version == version: + return v + available = [v.version for v in self.versions[:10]] + raise ValueError( + f"Version '{version}' not found in registry. " + f"Available versions: {available}" + ) + + +# -- Internal helpers ---------------------------------------------- + + +def _utc_now_iso() -> str: + """Return the current UTC time as an ISO 8601 string.""" + return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") + + +def _get_gcs_bucket() -> storage.Bucket: + """Return an authenticated GCS bucket handle.""" + credentials, project_id = google.auth.default() + client = storage.Client(credentials=credentials, project=project_id) + return client.bucket(GCS_BUCKET_NAME) + + +def _read_registry_from_gcs( + bucket: storage.Bucket, +) -> VersionRegistry: + """Read the version registry from GCS. + + Returns an empty registry if no registry exists yet. 
+ """ + blob = bucket.blob(REGISTRY_BLOB) + try: + content = blob.download_as_text() + except NotFound: + return VersionRegistry() + return VersionRegistry.from_dict(json.loads(content)) + + +def _upload_registry_to_gcs( + bucket: storage.Bucket, + registry: VersionRegistry, +) -> None: + """Write the version registry to GCS.""" + data = json.dumps(registry.to_dict(), indent=2) + blob = bucket.blob(REGISTRY_BLOB) + blob.upload_from_string(data, content_type="application/json") + logging.info(f"Uploaded registry to GCS (current={registry.current}).") + + +def _upload_registry_to_hf( + registry: VersionRegistry, +) -> None: + """Write the version registry to Hugging Face.""" + token = os.environ.get("HUGGING_FACE_TOKEN") + api = HfApi() + data = json.dumps(registry.to_dict(), indent=2) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + f.write(data) + tmp_path = f.name + + try: + api.upload_file( + path_or_fileobj=tmp_path, + path_in_repo=REGISTRY_BLOB, + repo_id=HF_REPO_NAME, + repo_type=HF_REPO_TYPE, + token=token, + commit_message=(f"Update version registry (current={registry.current})"), + ) + logging.info(f"Uploaded {REGISTRY_BLOB} to HF repo {HF_REPO_NAME}.") + finally: + os.unlink(tmp_path) + + +def _restore_gcs_generations( + bucket: storage.Bucket, + old_generations: dict[str, int], +) -> dict[str, int]: + """Copy old GCS generation blobs to live paths. + + Args: + bucket: GCS bucket containing the blobs. + old_generations: Map of blob path to old generation + number. + + Returns: + Map of blob path to new generation number. 
+ """ + new_generations: dict[str, int] = {} + for file_path, generation in old_generations.items(): + source_blob = bucket.blob(file_path, generation=generation) + bucket.copy_blob(source_blob, bucket, file_path) + restored_blob = bucket.get_blob(file_path) + new_generations[file_path] = restored_blob.generation + logging.info( + f"Restored {file_path}: generation " + f"{generation} -> {restored_blob.generation}." + ) + return new_generations + + +def _restore_hf_commit( + old_manifest: VersionManifest, + new_version: str, +) -> str: + """Re-upload old HF data as a new commit and tag it. + + Args: + old_manifest: The manifest of the version being + restored. + new_version: The new semver version string for + tagging. + + Returns: + The commit SHA of the new HF commit. + """ + token = os.environ.get("HUGGING_FACE_TOKEN") + api = HfApi() + target_version = old_manifest.version + + operations = [] + with tempfile.TemporaryDirectory() as tmpdir: + for file_path in old_manifest.gcs.generations: + hf_hub_download( + repo_id=old_manifest.hf.repo, + repo_type=HF_REPO_TYPE, + filename=file_path, + revision=old_manifest.hf.commit, + local_dir=tmpdir, + token=token, + ) + downloaded = os.path.join(tmpdir, file_path) + operations.append( + CommitOperationAdd( + path_in_repo=file_path, + path_or_fileobj=downloaded, + ) + ) + + commit_info = api.create_commit( + token=token, + repo_id=HF_REPO_NAME, + operations=operations, + repo_type=HF_REPO_TYPE, + commit_message=(f"Roll back to {target_version} as {new_version}"), + ) + + api.create_tag( + token=token, + repo_id=HF_REPO_NAME, + tag=new_version, + revision=commit_info.oid, + repo_type=HF_REPO_TYPE, + exist_ok=True, + ) + + return commit_info.oid + + +# -- Public API ---------------------------------------------------- + + +def build_manifest( + version: str, + blob_names: list[str], + hf_info: Optional[HFVersionInfo] = None, +) -> VersionManifest: + """Build a version manifest by reading generation + numbers from uploaded 
blobs. + + Args: + version: Semver version string. + blob_names: List of blob paths to include. + hf_info: Optional HF backend info to include. + + Returns: + A VersionManifest with generation numbers for + each blob. + """ + bucket = _get_gcs_bucket() + generations: dict[str, int] = {} + for name in blob_names: + blob = bucket.get_blob(name) + if blob is None: + raise ValueError( + f"Blob '{name}' not found in bucket '{bucket.name}' after upload." + ) + generations[name] = blob.generation + + return VersionManifest( + version=version, + created_at=_utc_now_iso(), + hf=hf_info, + gcs=GCSVersionInfo( + bucket=bucket.name, + generations=generations, + ), + ) + + +def upload_manifest( + manifest: VersionManifest, +) -> None: + """Append a version manifest to the registry and + upload to both GCS and HF. + + Reads the existing registry from GCS (or starts fresh), + prepends the new manifest, updates the current pointer, + and writes the registry to both backends. + + Args: + manifest: The version manifest to add. + """ + bucket = _get_gcs_bucket() + registry = _read_registry_from_gcs(bucket) + registry.versions.insert(0, manifest) + registry.current = manifest.version + _upload_registry_to_gcs(bucket, registry) + _upload_registry_to_hf(registry) + + +def get_current_version() -> Optional[str]: + """Get the current version from the registry. + + Returns: + The current semver version string, or None if no + registry exists. + """ + bucket = _get_gcs_bucket() + registry = _read_registry_from_gcs(bucket) + if not registry.current: + return None + return registry.current + + +def get_manifest(version: str) -> VersionManifest: + """Get the manifest for a specific version. + + Args: + version: Semver version string. + + Returns: + The deserialized VersionManifest. + + Raises: + ValueError: If the version is not in the registry. 
+ """ + bucket = _get_gcs_bucket() + registry = _read_registry_from_gcs(bucket) + return registry.get_version(version) + + +def list_versions() -> list[str]: + """List all available versions. + + Returns: + Sorted list of semver version strings. + """ + bucket = _get_gcs_bucket() + registry = _read_registry_from_gcs(bucket) + return sorted( + (v.version for v in registry.versions), + key=Version, + ) + + +def download_versioned_file( + file_path: str, + version: str, + local_path: str, +) -> str: + """Download a specific file at a specific version. + + Args: + file_path: Path of the file within the bucket. + version: Semver version string. + local_path: Local path to save the file to. + + Returns: + The local path where the file was saved. + + Raises: + ValueError: If the version or file is not found. + """ + bucket = _get_gcs_bucket() + registry = _read_registry_from_gcs(bucket) + manifest = registry.get_version(version) + + if file_path not in manifest.gcs.generations: + raise ValueError( + f"File '{file_path}' not found in manifest " + f"for version '{version}'. Available files: " + f"{list(manifest.gcs.generations.keys())[:10]}" + "..." + ) + + generation = manifest.gcs.generations[file_path] + blob = bucket.blob(file_path, generation=generation) + + Path(local_path).parent.mkdir(parents=True, exist_ok=True) + blob.download_to_filename(local_path) + + logging.info( + f"Downloaded {file_path} at version {version} " + f"(generation {generation}) to {local_path}." + ) + return local_path + + +def rollback( + target_version: str, + new_version: str, +) -> VersionManifest: + """Roll back by releasing a new version with old data. + + Treats rollback as a new release: data from + target_version is copied to the live paths (creating + new GCS generations), a new HF commit is created with + the old data, and a new manifest is published under + new_version with special_operation="roll-back". + + Args: + target_version: Semver version to roll back to. 
+ new_version: New semver version to publish. + + Returns: + The new VersionManifest for the rollback release. + + Raises: + ValueError: If target_version is not in the + registry. + """ + bucket = _get_gcs_bucket() + old_manifest = _read_registry_from_gcs(bucket).get_version(target_version) + + new_gens = _restore_gcs_generations(bucket, old_manifest.gcs.generations) + hf_commit = ( + _restore_hf_commit(old_manifest, new_version) if old_manifest.hf else None + ) + + manifest = VersionManifest( + version=new_version, + created_at=_utc_now_iso(), + hf=(HFVersionInfo(repo=HF_REPO_NAME, commit=hf_commit) if hf_commit else None), + gcs=GCSVersionInfo( + bucket=GCS_BUCKET_NAME, + generations=new_gens, + ), + special_operation="roll-back", + roll_back_version=target_version, + ) + upload_manifest(manifest) + + logging.info( + f"Rolled back to {target_version} as new " + f"version {new_version}. " + f"Restored {len(new_gens)} files." + ) + return manifest + + +# -- Consumer API -------------------------------------------------- + +_cached_registry: Optional[VersionRegistry] = None + + +def get_data_manifest() -> VersionRegistry: + """Get the full version registry from HF. + + Fetches version_manifest.json from the Hugging Face + repo and returns it as a VersionRegistry. The result + is cached in memory after the first call. + + Returns: + The full VersionRegistry. + """ + global _cached_registry + if _cached_registry is not None: + return _cached_registry + + local_path = hf_hub_download( + repo_id=HF_REPO_NAME, + repo_type=HF_REPO_TYPE, + filename=REGISTRY_BLOB, + ) + with open(local_path) as f: + data = json.load(f) + + _cached_registry = VersionRegistry.from_dict(data) + return _cached_registry + + +def get_data_version() -> str: + """Get the current deployed data version string. + + Convenience wrapper around get_data_manifest(). + + Returns: + The current semver version string. 
+ """ + return get_data_manifest().current diff --git a/pyproject.toml b/pyproject.toml index 46e23bfa..1a15adae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ "xlrd>=2.0.2", "spm-calculator>=0.1.0", "tenacity>=8.0.0", + "packaging>=21.0", ] [project.optional-dependencies] diff --git a/uv.lock b/uv.lock index e554b94e..92492e51 100644 --- a/uv.lock +++ b/uv.lock @@ -610,6 +610,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f8/0a/a3871375c7b9727edaeeea994bfff7c63ff7804c9829c19309ba2e058807/greenlet-3.3.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:b01548f6e0b9e9784a2c99c5651e5dc89ffcbe870bc5fb2e5ef864e9cc6b5dcb", size = 276379, upload-time = "2025-12-04T14:23:30.498Z" }, { url = "https://files.pythonhosted.org/packages/43/ab/7ebfe34dce8b87be0d11dae91acbf76f7b8246bf9d6b319c741f99fa59c6/greenlet-3.3.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:349345b770dc88f81506c6861d22a6ccd422207829d2c854ae2af8025af303e3", size = 597294, upload-time = "2025-12-04T14:50:06.847Z" }, { url = "https://files.pythonhosted.org/packages/a4/39/f1c8da50024feecd0793dbd5e08f526809b8ab5609224a2da40aad3a7641/greenlet-3.3.0-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e8e18ed6995e9e2c0b4ed264d2cf89260ab3ac7e13555b8032b25a74c6d18655", size = 607742, upload-time = "2025-12-04T14:57:42.349Z" }, + { url = "https://files.pythonhosted.org/packages/77/cb/43692bcd5f7a0da6ec0ec6d58ee7cddb606d055ce94a62ac9b1aa481e969/greenlet-3.3.0-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:c024b1e5696626890038e34f76140ed1daf858e37496d33f2af57f06189e70d7", size = 622297, upload-time = "2025-12-04T15:07:13.552Z" }, { url = "https://files.pythonhosted.org/packages/75/b0/6bde0b1011a60782108c01de5913c588cf51a839174538d266de15e4bf4d/greenlet-3.3.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:047ab3df20ede6a57c35c14bf5200fcf04039d50f908270d3f9a7a82064f543b", size = 609885, upload-time = "2025-12-04T14:26:02.368Z" }, { url = "https://files.pythonhosted.org/packages/49/0e/49b46ac39f931f59f987b7cd9f34bfec8ef81d2a1e6e00682f55be5de9f4/greenlet-3.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2d9ad37fc657b1102ec880e637cccf20191581f75c64087a549e66c57e1ceb53", size = 1567424, upload-time = "2025-12-04T15:04:23.757Z" }, { url = "https://files.pythonhosted.org/packages/05/f5/49a9ac2dff7f10091935def9165c90236d8f175afb27cbed38fb1d61ab6b/greenlet-3.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83cd0e36932e0e7f36a64b732a6f60c2fc2df28c351bae79fbaf4f8092fe7614", size = 1636017, upload-time = "2025-12-04T14:27:29.688Z" }, @@ -617,6 +618,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/2f/28592176381b9ab2cafa12829ba7b472d177f3acc35d8fbcf3673d966fff/greenlet-3.3.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:a1e41a81c7e2825822f4e068c48cb2196002362619e2d70b148f20a831c00739", size = 275140, upload-time = "2025-12-04T14:23:01.282Z" }, { url = "https://files.pythonhosted.org/packages/2c/80/fbe937bf81e9fca98c981fe499e59a3f45df2a04da0baa5c2be0dca0d329/greenlet-3.3.0-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9f515a47d02da4d30caaa85b69474cec77b7929b2e936ff7fb853d42f4bf8808", size = 599219, upload-time = "2025-12-04T14:50:08.309Z" }, { url = "https://files.pythonhosted.org/packages/c2/ff/7c985128f0514271b8268476af89aee6866df5eec04ac17dcfbc676213df/greenlet-3.3.0-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7d2d9fd66bfadf230b385fdc90426fcd6eb64db54b40c495b72ac0feb5766c54", size = 610211, upload-time = "2025-12-04T14:57:43.968Z" }, + { url = "https://files.pythonhosted.org/packages/79/07/c47a82d881319ec18a4510bb30463ed6891f2ad2c1901ed5ec23d3de351f/greenlet-3.3.0-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = 
"sha256:30a6e28487a790417d036088b3bcb3f3ac7d8babaa7d0139edbaddebf3af9492", size = 624311, upload-time = "2025-12-04T15:07:14.697Z" }, { url = "https://files.pythonhosted.org/packages/fd/8e/424b8c6e78bd9837d14ff7df01a9829fc883ba2ab4ea787d4f848435f23f/greenlet-3.3.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:087ea5e004437321508a8d6f20efc4cfec5e3c30118e1417ea96ed1d93950527", size = 612833, upload-time = "2025-12-04T14:26:03.669Z" }, { url = "https://files.pythonhosted.org/packages/b5/ba/56699ff9b7c76ca12f1cdc27a886d0f81f2189c3455ff9f65246780f713d/greenlet-3.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ab97cf74045343f6c60a39913fa59710e4bd26a536ce7ab2397adf8b27e67c39", size = 1567256, upload-time = "2025-12-04T15:04:25.276Z" }, { url = "https://files.pythonhosted.org/packages/1e/37/f31136132967982d698c71a281a8901daf1a8fbab935dce7c0cf15f942cc/greenlet-3.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5375d2e23184629112ca1ea89a53389dddbffcf417dad40125713d88eb5f96e8", size = 1636483, upload-time = "2025-12-04T14:27:30.804Z" }, @@ -1820,6 +1822,7 @@ dependencies = [ { name = "microdf-python" }, { name = "microimpute" }, { name = "openpyxl" }, + { name = "packaging" }, { name = "pandas" }, { name = "pip-system-certs" }, { name = "policyengine-core" }, @@ -1871,6 +1874,7 @@ requires-dist = [ { name = "microdf-python", specifier = ">=1.2.1" }, { name = "microimpute", specifier = ">=1.15.1" }, { name = "openpyxl", specifier = ">=3.1.5" }, + { name = "packaging", specifier = ">=21.0" }, { name = "pandas", specifier = ">=2.3.1" }, { name = "pip-system-certs", specifier = ">=3.0" }, { name = "policyengine-core", specifier = ">=3.23.6" },