From 9e3b66ad7e65c8988c04eee414908e6a8c11199b Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 11 Mar 2026 18:40:24 +0100 Subject: [PATCH 1/8] Add unified version registry for GCS and HF dataset versioning Introduces a single version_manifest.json that tracks all dataset versions across both GCS and Hugging Face, enabling programmatic version queries and rollback. Adds consumer API (get_data_version, get_data_manifest) exported at package level. Co-Authored-By: Claude Opus 4.6 --- policyengine_us_data/__init__.py | 1 + .../tests/test_gcs_version.py | 927 ++++++++++++++++++ policyengine_us_data/utils/data_upload.py | 125 ++- policyengine_us_data/utils/gcs_version.py | 574 +++++++++++ 4 files changed, 1601 insertions(+), 26 deletions(-) create mode 100644 policyengine_us_data/tests/test_gcs_version.py create mode 100644 policyengine_us_data/utils/gcs_version.py diff --git a/policyengine_us_data/__init__.py b/policyengine_us_data/__init__.py index 173835349..e2707f030 100644 --- a/policyengine_us_data/__init__.py +++ b/policyengine_us_data/__init__.py @@ -1,2 +1,3 @@ from .datasets import * from .geography import ZIP_CODE_DATASET +from .utils.gcs_version import get_data_version, get_data_manifest diff --git a/policyengine_us_data/tests/test_gcs_version.py b/policyengine_us_data/tests/test_gcs_version.py new file mode 100644 index 000000000..27d91094c --- /dev/null +++ b/policyengine_us_data/tests/test_gcs_version.py @@ -0,0 +1,927 @@ +"""Tests for GCS version registry system.""" + +import json +from unittest.mock import MagicMock, patch, call + +import pytest +from google.api_core.exceptions import NotFound + +from policyengine_us_data.utils.gcs_version import ( + HFVersionInfo, + GCSVersionInfo, + VersionManifest, + VersionRegistry, + build_manifest, + upload_manifest, + get_current_version, + get_manifest, + list_versions, + download_versioned_file, + rollback, + get_data_manifest, + get_data_version, +) + +# -- Fixtures ------------------------------------------------------- + + +@pytest.fixture +def sample_generations(): + return { + "enhanced_cps_2024.h5": 1710203948123456, + "cps_2024.h5": 1710203948234567, + "states/AL.h5": 1710203948345678, + } + + +@pytest.fixture +def sample_hf_info(): + return HFVersionInfo( + repo="policyengine/policyengine-us-data", + commit="abc123def456", + ) + + +@pytest.fixture +def sample_manifest(sample_generations, sample_hf_info): + return VersionManifest( + version="1.72.3", + created_at="2026-03-10T14:30:00Z", + hf=sample_hf_info, + gcs=GCSVersionInfo( + bucket="policyengine-us-data", + generations=sample_generations, + ), + ) + + +@pytest.fixture +def sample_registry(sample_manifest): + """A registry with one version entry.""" + return VersionRegistry( + current="1.72.3", + versions=[sample_manifest], + ) + + +@pytest.fixture +def mock_bucket(): + bucket = MagicMock() + bucket.name = "policyengine-us-data" + return bucket + + +def _make_mock_blob(generation: int): + blob = MagicMock() + blob.generation = generation + return blob + + +def _setup_bucket_with_registry(bucket, registry): + """Configure a mock bucket to serve a registry.""" + registry_json = json.dumps(registry.to_dict()) + blob = MagicMock() + blob.download_as_text.return_value = registry_json + bucket.blob.return_value = blob + + +# -- VersionManifest serialization tests --------------------------- + + +class TestVersionManifestSerialization: + def test_to_dict(self, sample_manifest): + result = sample_manifest.to_dict() + + assert result["version"] == "1.72.3" + assert result["created_at"] == "2026-03-10T14:30:00Z" + assert result["hf"]["repo"] == ("policyengine/policyengine-us-data") + assert result["hf"]["commit"] == "abc123def456" + assert result["gcs"]["bucket"] == "policyengine-us-data" + assert ( + result["gcs"]["generations"]["enhanced_cps_2024.h5"] + == 1710203948123456 + ) + + def test_from_dict(self, sample_manifest): + data = { + "version": "1.72.3", + "created_at": "2026-03-10T14:30:00Z", + "hf": { + "repo": "policyengine/policyengine-us-data", + "commit": "abc123def456", + }, + "gcs": { + "bucket": "policyengine-us-data", + "generations": { + "enhanced_cps_2024.h5": 1710203948123456, + "cps_2024.h5": 1710203948234567, + "states/AL.h5": 1710203948345678, + }, + }, + } + result = VersionManifest.from_dict(data) + + assert result.version == "1.72.3" + assert result.hf.commit == "abc123def456" + assert result.hf.repo == ("policyengine/policyengine-us-data") + assert ( + result.gcs.generations["enhanced_cps_2024.h5"] == 1710203948123456 + ) + assert result.gcs.bucket == "policyengine-us-data" + + def test_roundtrip(self, sample_manifest): + roundtripped = VersionManifest.from_dict(sample_manifest.to_dict()) + + assert roundtripped.version == sample_manifest.version + assert roundtripped.created_at == (sample_manifest.created_at) + assert roundtripped.hf.repo == sample_manifest.hf.repo + assert roundtripped.hf.commit == (sample_manifest.hf.commit) + assert roundtripped.gcs.bucket == (sample_manifest.gcs.bucket) + assert roundtripped.gcs.generations == ( + sample_manifest.gcs.generations + ) + + def test_without_hf(self, sample_generations): + manifest = VersionManifest( + version="1.72.3", + created_at="2026-03-10T14:30:00Z", + hf=None, + gcs=GCSVersionInfo( + bucket="policyengine-us-data", + generations=sample_generations, + ), + ) + data = manifest.to_dict() + assert data["hf"] is None + + roundtripped = VersionManifest.from_dict(data) + assert roundtripped.hf is None + assert roundtripped.gcs.generations == (sample_generations) + + def test_special_operation_omitted_by_default(self, sample_manifest): + data = sample_manifest.to_dict() + assert "special_operation" not in data + assert "roll_back_version" not in data + + def test_special_operation_included_when_set( + self, sample_generations, sample_hf_info + ): + manifest = VersionManifest( + version="1.73.0", + created_at="2026-03-10T15:00:00Z", + hf=sample_hf_info, + gcs=GCSVersionInfo( + bucket="policyengine-us-data", + generations=sample_generations, + ), + special_operation="roll-back", + roll_back_version="1.70.1", + ) + data = manifest.to_dict() + assert data["special_operation"] == "roll-back" + assert data["roll_back_version"] == "1.70.1" + + def test_special_operation_roundtrip( + self, sample_generations, sample_hf_info + ): + manifest = VersionManifest( + version="1.73.0", + created_at="2026-03-10T15:00:00Z", + hf=sample_hf_info, + gcs=GCSVersionInfo( + bucket="policyengine-us-data", + generations=sample_generations, + ), + special_operation="roll-back", + roll_back_version="1.70.1", + ) + roundtripped = VersionManifest.from_dict(manifest.to_dict()) + assert roundtripped.special_operation == "roll-back" + assert roundtripped.roll_back_version == "1.70.1" + + def test_regular_manifest_has_no_special_operation(self): + data = { + "version": "1.72.3", + "created_at": "2026-03-10T14:30:00Z", + "hf": None, + "gcs": { + "bucket": "b", + "generations": {"f.h5": 123}, + }, + } + result = VersionManifest.from_dict(data) + assert result.special_operation is None + assert result.roll_back_version is None + + +# -- VersionRegistry serialization tests --------------------------- + + +class TestVersionRegistrySerialization: + def test_to_dict(self, sample_registry): + result = sample_registry.to_dict() + + assert result["current"] == "1.72.3" + assert len(result["versions"]) == 1 + assert result["versions"][0]["version"] == "1.72.3" + + def test_from_dict(self, sample_manifest): + data = { + "current": "1.72.3", + "versions": [sample_manifest.to_dict()], + } + result = VersionRegistry.from_dict(data) + + assert result.current == "1.72.3" + assert len(result.versions) == 1 + assert result.versions[0].version == "1.72.3" + assert result.versions[0].hf.commit == "abc123def456" + + def test_roundtrip(self, sample_registry): + roundtripped = VersionRegistry.from_dict(sample_registry.to_dict()) + assert roundtripped.current == (sample_registry.current) + assert len(roundtripped.versions) == len(sample_registry.versions) + assert roundtripped.versions[0].version == "1.72.3" + + def test_get_version(self, sample_registry): + result = sample_registry.get_version("1.72.3") + assert result.version == "1.72.3" + assert result.hf.commit == "abc123def456" + + def test_get_version_not_found(self, sample_registry): + with pytest.raises(ValueError, match="not found"): + sample_registry.get_version("9.9.9") + + def test_empty_registry(self): + registry = VersionRegistry() + assert registry.current == "" + assert registry.versions == [] + + data = registry.to_dict() + assert data == {"current": "", "versions": []} + + +# -- build_manifest tests ------------------------------------------ + + +class TestBuildManifest: + def test_structure(self, mock_bucket): + blob_names = [ + "file_a.h5", + "file_b.h5", + "file_c.h5", + ] + mock_bucket.get_blob.side_effect = [ + _make_mock_blob(100), + _make_mock_blob(200), + _make_mock_blob(300), + ] + + result = build_manifest(mock_bucket, "1.72.3", blob_names) + + assert isinstance(result, VersionManifest) + assert result.version == "1.72.3" + assert result.created_at.endswith("Z") + assert result.gcs.generations == { + "file_a.h5": 100, + "file_b.h5": 200, + "file_c.h5": 300, + } + assert result.gcs.bucket == "policyengine-us-data" + assert result.hf is None + + def test_with_subdirectories(self, mock_bucket): + blob_names = [ + "states/AL.h5", + "districts/CA-01.h5", + ] + mock_bucket.get_blob.side_effect = [ + _make_mock_blob(111), + _make_mock_blob(222), + ] + + result = build_manifest(mock_bucket, "1.72.3", blob_names) + + assert "states/AL.h5" in result.gcs.generations + assert "districts/CA-01.h5" in result.gcs.generations + assert result.gcs.generations["states/AL.h5"] == 111 + assert result.gcs.generations["districts/CA-01.h5"] == 222 + + def test_with_hf_info(self, mock_bucket, sample_hf_info): + mock_bucket.get_blob.return_value = _make_mock_blob(999) + + result = build_manifest( + mock_bucket, + "1.72.3", + ["file.h5"], + hf_info=sample_hf_info, + ) + + assert result.hf is not None + assert result.hf.commit == "abc123def456" + assert result.hf.repo == ("policyengine/policyengine-us-data") + + def test_missing_blob_raises(self, mock_bucket): + mock_bucket.get_blob.return_value = None + + with pytest.raises(ValueError, match="not found"): + build_manifest(mock_bucket, "1.72.3", ["missing.h5"]) + + +# -- upload_manifest tests ----------------------------------------- + + +class TestUploadManifest: + def _setup_empty_registry(self, bucket): + """Mock bucket with no existing registry.""" + blob = MagicMock() + blob.download_as_text.side_effect = NotFound("Not found") + # First call reads existing registry (NotFound), + # subsequent calls are for writing + written = {} + + def mock_blob(name): + if name == "version_manifest.json": + b = MagicMock() + b.name = name + b.download_as_text.side_effect = NotFound("Not found") + written[name] = b + return b + b = MagicMock() + b.name = name + written[name] = b + return b + + bucket.blob.side_effect = mock_blob + return written + + def test_writes_registry_to_gcs(self, mock_bucket, sample_manifest): + written = self._setup_empty_registry(mock_bucket) + + upload_manifest(mock_bucket, sample_manifest) + + assert "version_manifest.json" in written + blob = written["version_manifest.json"] + written_json = blob.upload_from_string.call_args[0][0] + registry_data = json.loads(written_json) + + assert registry_data["current"] == "1.72.3" + assert len(registry_data["versions"]) == 1 + assert registry_data["versions"][0]["version"] == "1.72.3" + + def test_includes_hf_commit(self, mock_bucket, sample_manifest): + written = self._setup_empty_registry(mock_bucket) + + upload_manifest(mock_bucket, sample_manifest) + + blob = written["version_manifest.json"] + written_json = blob.upload_from_string.call_args[0][0] + registry_data = json.loads(written_json) + + assert registry_data["versions"][0]["hf"]["commit"] == "abc123def456" + + def test_appends_to_existing_registry(self, mock_bucket, sample_manifest): + # Pre-populate with an older version + older = VersionManifest( + version="1.72.2", + created_at="2026-03-09T10:00:00Z", + hf=None, + gcs=GCSVersionInfo( + bucket="policyengine-us-data", + generations={"old.h5": 111}, + ), + ) + existing_registry = VersionRegistry(current="1.72.2", versions=[older]) + existing_json = json.dumps(existing_registry.to_dict()) + written = {} + + def mock_blob(name): + b = MagicMock() + b.name = name + b.download_as_text.return_value = existing_json + written[name] = b + return b + + mock_bucket.blob.side_effect = mock_blob + + upload_manifest(mock_bucket, sample_manifest) + + blob = written["version_manifest.json"] + written_json = blob.upload_from_string.call_args[0][0] + registry_data = json.loads(written_json) + + assert registry_data["current"] == "1.72.3" + assert len(registry_data["versions"]) == 2 + # Most recent first + assert registry_data["versions"][0]["version"] == "1.72.3" + assert registry_data["versions"][1]["version"] == "1.72.2" + + @patch("policyengine_us_data.utils.gcs_version.os") + @patch("policyengine_us_data.utils.gcs_version.HfApi") + def test_uploads_to_hf_when_repo_provided( + self, + mock_hf_api_cls, + mock_os, + mock_bucket, + sample_manifest, + ): + mock_os.environ.get.return_value = "fake_token" + mock_os.unlink = MagicMock() + mock_api = MagicMock() + mock_hf_api_cls.return_value = mock_api + + # Mock GCS read (empty registry) + blob = MagicMock() + blob.download_as_text.side_effect = NotFound("Not found") + mock_bucket.blob.return_value = blob + + upload_manifest( + mock_bucket, + sample_manifest, + hf_repo_name=("policyengine/policyengine-us-data"), + ) + + mock_api.upload_file.assert_called_once() + call_kwargs = mock_api.upload_file.call_args.kwargs + assert call_kwargs["path_in_repo"] == ("version_manifest.json") + assert call_kwargs["repo_id"] == ("policyengine/policyengine-us-data") + + def test_skips_hf_when_no_repo(self, mock_bucket, sample_manifest): + self._setup_empty_registry(mock_bucket) + + # No hf_repo_name — should not raise or call HF + upload_manifest(mock_bucket, sample_manifest) + + +# -- get_current_version tests ------------------------------------- + + +class TestGetCurrentVersion: + def test_returns_version(self, mock_bucket, sample_registry): + _setup_bucket_with_registry(mock_bucket, sample_registry) + + result = get_current_version(mock_bucket) + + assert result == "1.72.3" + mock_bucket.blob.assert_called_with("version_manifest.json") + + def test_no_registry_returns_none(self, mock_bucket): + blob = MagicMock() + blob.download_as_text.side_effect = NotFound("Not found") + mock_bucket.blob.return_value = blob + + result = get_current_version(mock_bucket) + + assert result is None + + +# -- get_manifest tests --------------------------------------------- + + +class TestGetManifest: + def test_specific_version(self, mock_bucket, sample_registry): + _setup_bucket_with_registry(mock_bucket, sample_registry) + + result = get_manifest(mock_bucket, "1.72.3") + + assert isinstance(result, VersionManifest) + assert result.version == "1.72.3" + assert result.hf.commit == "abc123def456" + assert ( + result.gcs.generations["enhanced_cps_2024.h5"] == 1710203948123456 + ) + + def test_nonexistent_version(self, mock_bucket, sample_registry): + _setup_bucket_with_registry(mock_bucket, sample_registry) + + with pytest.raises(ValueError, match="not found"): + get_manifest(mock_bucket, "9.9.9") + + def test_no_registry_raises(self, mock_bucket): + blob = MagicMock() + blob.download_as_text.side_effect = NotFound("Not found") + mock_bucket.blob.return_value = blob + + with pytest.raises(ValueError, match="not found"): + get_manifest(mock_bucket, "1.72.3") + + +# -- list_versions tests ------------------------------------------- + + +class TestListVersions: + def test_returns_sorted(self, mock_bucket): + v1 = VersionManifest( + version="1.72.1", + created_at="t1", + hf=None, + gcs=GCSVersionInfo(bucket="b", generations={"f.h5": 1}), + ) + v2 = VersionManifest( + version="1.72.3", + created_at="t2", + hf=None, + gcs=GCSVersionInfo(bucket="b", generations={"f.h5": 2}), + ) + v3 = VersionManifest( + version="1.72.2", + created_at="t3", + hf=None, + gcs=GCSVersionInfo(bucket="b", generations={"f.h5": 3}), + ) + registry = VersionRegistry(current="1.72.3", versions=[v2, v3, v1]) + _setup_bucket_with_registry(mock_bucket, registry) + + result = list_versions(mock_bucket) + + assert result == ["1.72.1", "1.72.2", "1.72.3"] + + def test_empty(self, mock_bucket): + registry = VersionRegistry() + _setup_bucket_with_registry(mock_bucket, registry) + + result = list_versions(mock_bucket) + + assert result == [] + + +# -- download_versioned_file tests --------------------------------- + + +class TestDownloadVersionedFile: + def test_downloads_correct_generation( + self, mock_bucket, sample_manifest, tmp_path + ): + registry = VersionRegistry( + current="1.72.3", versions=[sample_manifest] + ) + registry_json = json.dumps(registry.to_dict()) + + def mock_blob(name, generation=None): + if name == "version_manifest.json": + blob = MagicMock() + blob.download_as_text.return_value = registry_json + return blob + blob = MagicMock() + blob.name = name + blob.generation = generation + return blob + + mock_bucket.blob.side_effect = mock_blob + + local_path = str(tmp_path / "AL.h5") + download_versioned_file( + mock_bucket, + "states/AL.h5", + "1.72.3", + local_path, + ) + + calls = mock_bucket.blob.call_args_list + gen_call = [ + c + for c in calls + if c + == call( + "states/AL.h5", + generation=1710203948345678, + ) + ] + assert len(gen_call) == 1 + + def test_file_not_in_manifest( + self, mock_bucket, sample_manifest, tmp_path + ): + registry = VersionRegistry( + current="1.72.3", versions=[sample_manifest] + ) + _setup_bucket_with_registry(mock_bucket, registry) + + with pytest.raises(ValueError, match="not found"): + download_versioned_file( + mock_bucket, + "nonexistent.h5", + "1.72.3", + str(tmp_path / "out.h5"), + ) + + +# -- rollback tests ------------------------------------------------- + + +class TestRollback: + @patch("policyengine_us_data.utils.gcs_version." "CommitOperationAdd") + @patch("policyengine_us_data.utils.gcs_version." "hf_hub_download") + @patch("policyengine_us_data.utils.gcs_version.HfApi") + @patch("policyengine_us_data.utils.gcs_version.os") + def test_creates_new_version_with_old_data( + self, + mock_os, + mock_hf_api_cls, + mock_hf_download, + mock_commit_op, + mock_bucket, + sample_manifest, + ): + mock_os.environ.get.return_value = "fake_token" + mock_os.path.join = lambda *args: "/".join(args) + mock_os.unlink = MagicMock() + + # Setup HF mock + mock_api = MagicMock() + mock_hf_api_cls.return_value = mock_api + commit_info = MagicMock() + commit_info.oid = "new_commit_sha" + mock_api.create_commit.return_value = commit_info + + # Setup bucket: get_manifest reads registry, + # upload_manifest reads then writes registry + registry = VersionRegistry( + current="1.72.3", + versions=[sample_manifest], + ) + registry_json = json.dumps(registry.to_dict()) + written = {} + + def mock_blob(name, generation=None): + if name == "version_manifest.json": + b = MagicMock() + b.name = name + b.download_as_text.return_value = registry_json + written[name] = b + return b + blob = MagicMock() + blob.name = name + blob.generation = generation + return blob + + mock_bucket.blob.side_effect = mock_blob + + # get_blob returns blobs with new generations + new_gen_counter = iter([50001, 50002, 50003]) + + def mock_get_blob(name): + blob = MagicMock() + blob.generation = next(new_gen_counter) + return blob + + mock_bucket.get_blob.side_effect = mock_get_blob + + result = rollback( + mock_bucket, + target_version="1.72.3", + new_version="1.73.0", + ) + + assert isinstance(result, VersionManifest) + assert result.version == "1.73.0" + assert result.special_operation == "roll-back" + assert result.roll_back_version == "1.72.3" + + # GCS files were copied + assert mock_bucket.copy_blob.call_count == 3 + + # Registry was written with both versions + blob = written["version_manifest.json"] + written_json = blob.upload_from_string.call_args[0][0] + registry_data = json.loads(written_json) + + assert registry_data["current"] == "1.73.0" + assert len(registry_data["versions"]) == 2 + assert registry_data["versions"][0]["version"] == "1.73.0" + assert registry_data["versions"][0]["special_operation"] == "roll-back" + + # HF commit was created and tagged + mock_api.create_commit.assert_called_once() + commit_msg = mock_api.create_commit.call_args.kwargs["commit_message"] + assert "1.72.3" in commit_msg + assert "1.73.0" in commit_msg + mock_api.create_tag.assert_called_once() + + def test_nonexistent_version(self, mock_bucket): + blob = MagicMock() + blob.download_as_text.side_effect = NotFound("Not found") + mock_bucket.blob.return_value = blob + + with pytest.raises(ValueError, match="not found"): + rollback( + mock_bucket, + target_version="9.9.9", + new_version="9.10.0", + ) + + +# -- upload_files_to_gcs return value test -------------------------- + + +class TestUploadFilesToGcsReturnsGenerations: + @patch("policyengine_us_data.utils.data_upload." "google.auth") + @patch("policyengine_us_data.utils.data_upload.storage") + def test_returns_generations(self, mock_storage, mock_auth, tmp_path): + from policyengine_us_data.utils.data_upload import ( + upload_files_to_gcs, + ) + + mock_auth.default.return_value = ( + MagicMock(), + "project", + ) + mock_client = MagicMock() + mock_storage.Client.return_value = mock_client + mock_bucket = MagicMock() + mock_client.bucket.return_value = mock_bucket + + file_a = tmp_path / "file_a.h5" + file_a.write_bytes(b"data_a") + file_b = tmp_path / "file_b.h5" + file_b.write_bytes(b"data_b") + + blob_a = MagicMock() + blob_a.generation = 99999 + blob_b = MagicMock() + blob_b.generation = 88888 + + mock_bucket.blob.side_effect = [blob_a, blob_b] + + result = upload_files_to_gcs( + files=[str(file_a), str(file_b)], + version="1.72.3", + ) + + assert result == { + "file_a.h5": 99999, + "file_b.h5": 88888, + } + + +# -- End-to-end upload creates registry test ------------------------ + + +class TestEndToEndUploadCreatesRegistry: + @patch("policyengine_us_data.utils.data_upload." "google.auth") + @patch("policyengine_us_data.utils.data_upload.storage") + @patch("policyengine_us_data.utils.data_upload.HfApi") + @patch("policyengine_us_data.utils.data_upload.os") + def test_creates_registry( + self, + mock_os, + mock_hf_api_cls, + mock_storage, + mock_auth, + tmp_path, + ): + from policyengine_us_data.utils.data_upload import ( + upload_data_files, + ) + + mock_auth.default.return_value = ( + MagicMock(), + "project", + ) + mock_os.environ.get.return_value = "fake_token" + + mock_api = MagicMock() + mock_hf_api_cls.return_value = mock_api + commit_info = MagicMock() + commit_info.oid = "abc123" + mock_api.create_commit.return_value = commit_info + + mock_client = MagicMock() + mock_storage.Client.return_value = mock_client + mock_bucket = MagicMock() + mock_bucket.name = "policyengine-us-data" + mock_client.bucket.return_value = mock_bucket + + blob_data = MagicMock() + blob_data.generation = 55555 + + written = {} + + def mock_blob(name): + if name == "version_manifest.json": + b = MagicMock() + b.name = name + b.download_as_text.side_effect = NotFound("Not found") + written[name] = b + return b + return blob_data + + mock_bucket.blob.side_effect = mock_blob + + test_file = tmp_path / "test.h5" + test_file.write_bytes(b"test_data") + + upload_data_files( + files=[str(test_file)], + version="1.72.3", + ) + + assert "version_manifest.json" in written + blob = written["version_manifest.json"] + written_json = blob.upload_from_string.call_args[0][0] + registry_data = json.loads(written_json) + + assert registry_data["current"] == "1.72.3" + assert len(registry_data["versions"]) == 1 + assert registry_data["versions"][0]["hf"]["commit"] == "abc123" + assert "test.h5" in registry_data["versions"][0]["gcs"]["generations"] + + +# -- Consumer API tests -------------------------------------------- + + +class TestGetDataManifest: + def setup_method(self): + import policyengine_us_data.utils.gcs_version as mod + + mod._cached_registry = None + + def teardown_method(self): + import policyengine_us_data.utils.gcs_version as mod + + mod._cached_registry = None + + @patch("policyengine_us_data.utils.gcs_version." "hf_hub_download") + def test_returns_registry(self, mock_download, tmp_path): + registry_data = { + "current": "1.72.3", + "versions": [ + { + "version": "1.72.3", + "created_at": "2026-03-10T14:30:00Z", + "hf": { + "repo": ("policyengine/" "policyengine-us-data"), + "commit": "abc123", + }, + "gcs": { + "bucket": "policyengine-us-data", + "generations": {"file.h5": 12345}, + }, + }, + ], + } + registry_file = tmp_path / "version_manifest.json" + registry_file.write_text(json.dumps(registry_data)) + mock_download.return_value = str(registry_file) + + result = get_data_manifest() + + assert isinstance(result, VersionRegistry) + assert result.current == "1.72.3" + assert len(result.versions) == 1 + assert result.versions[0].hf.commit == "abc123" + mock_download.assert_called_once_with( + repo_id="policyengine/policyengine-us-data", + repo_type="model", + filename="version_manifest.json", + ) + + @patch("policyengine_us_data.utils.gcs_version." "hf_hub_download") + def test_caches_result(self, mock_download, tmp_path): + registry_data = { + "current": "1.72.3", + "versions": [ + { + "version": "1.72.3", + "created_at": "2026-03-10T14:30:00Z", + "hf": None, + "gcs": { + "bucket": "b", + "generations": {"f.h5": 1}, + }, + }, + ], + } + registry_file = tmp_path / "version_manifest.json" + registry_file.write_text(json.dumps(registry_data)) + mock_download.return_value = str(registry_file) + + first = get_data_manifest() + second = get_data_manifest() + + assert first is second + assert mock_download.call_count == 1 + + @patch("policyengine_us_data.utils.gcs_version." "hf_hub_download") + def test_get_data_version(self, mock_download, tmp_path): + registry_data = { + "current": "1.72.3", + "versions": [ + { + "version": "1.72.3", + "created_at": "2026-03-10T14:30:00Z", + "hf": None, + "gcs": { + "bucket": "b", + "generations": {"f.h5": 1}, + }, + }, + ], + } + registry_file = tmp_path / "version_manifest.json" + registry_file.write_text(json.dumps(registry_data)) + mock_download.return_value = str(registry_file) + + result = get_data_version() + + assert result == "1.72.3" diff --git a/policyengine_us_data/utils/data_upload.py b/policyengine_us_data/utils/data_upload.py index 7b7481b3e..6ff7c233c 100644 --- a/policyengine_us_data/utils/data_upload.py +++ b/policyengine_us_data/utils/data_upload.py @@ -33,33 +33,70 @@ def upload_data_files( gcs_bucket_name: str = "policyengine-us-data", hf_repo_name: str = "policyengine/policyengine-us-data", hf_repo_type: str = "model", - version: str = None, -): + version: Optional[str] = None, +) -> None: + from policyengine_us_data.utils.gcs_version import ( + GCSVersionInfo, + HFVersionInfo, + VersionManifest, + upload_manifest, + ) + from datetime import datetime, timezone + if version is None: version = metadata.version("policyengine-us-data") - upload_files_to_hf( + hf_commit = upload_files_to_hf( files=files, version=version, hf_repo_name=hf_repo_name, hf_repo_type=hf_repo_type, ) - upload_files_to_gcs( + generations = upload_files_to_gcs( files=files, version=version, gcs_bucket_name=gcs_bucket_name, ) + # Build and upload version manifest + credentials, project_id = google.auth.default() + storage_client = storage.Client( + credentials=credentials, project=project_id + ) + bucket = storage_client.bucket(gcs_bucket_name) + + manifest = VersionManifest( + version=version, + created_at=datetime.now(timezone.utc) + .isoformat() + .replace("+00:00", "Z"), + hf=HFVersionInfo(repo=hf_repo_name, commit=hf_commit), + gcs=GCSVersionInfo( + bucket=gcs_bucket_name, + generations=generations, + ), + ) + upload_manifest( + bucket, + manifest, + hf_repo_name=hf_repo_name, + hf_repo_type=hf_repo_type, + ) + logging.info(f"Created version manifest for {version}.") + def upload_files_to_hf( files: List[str], version: str, hf_repo_name: str = "policyengine/policyengine-us-data", hf_repo_type: str = "model", -): - """ - Upload files to Hugging Face repository and tag the commit with the version. +) -> str: + """Upload files to Hugging Face repository and tag the + commit with the version. + + Returns: + The commit SHA (oid) of the created commit. """ api = HfApi() hf_operations = [] @@ -86,7 +123,7 @@ def upload_files_to_hf( ) logging.info(f"Uploaded files to Hugging Face repository {hf_repo_name}.") - # Tag commit with version + # Tag commit with version (convenience for HF web UI) try: api.create_tag( token=token, @@ -106,32 +143,50 @@ def upload_files_to_hf( else: raise + return commit_info.oid + def upload_files_to_gcs( files: List[str], version: str, gcs_bucket_name: str = "policyengine-us-data", -): - """ - Upload files to Google Cloud Storage and set metadata with the version. +) -> Dict[str, int]: + """Upload files to Google Cloud Storage and set metadata + with the version. + + Returns: + Dict mapping blob name to its GCS generation number. """ credentials, project_id = google.auth.default() - storage_client = storage.Client(credentials=credentials, project=project_id) + storage_client = storage.Client( + credentials=credentials, project=project_id + ) bucket = storage_client.bucket(gcs_bucket_name) + generations: Dict[str, int] = {} for file_path in files: file_path = Path(file_path) blob = bucket.blob(file_path.name) blob.upload_from_filename(file_path) - logging.info(f"Uploaded {file_path.name} to GCS bucket {gcs_bucket_name}.") + logging.info( + f"Uploaded {file_path.name} to GCS bucket " f"{gcs_bucket_name}." + ) # Set metadata blob.metadata = {"version": version} blob.patch() + + # Read back generation number for manifest + blob.reload() + generations[file_path.name] = blob.generation logging.info( - f"Set metadata for {file_path.name} in GCS bucket {gcs_bucket_name}." + f"Set metadata for {file_path.name} in GCS " + f"bucket {gcs_bucket_name} " + f"(generation {blob.generation})." ) + return generations + def upload_local_area_file( file_path: str, @@ -141,15 +196,19 @@ def upload_local_area_file( hf_repo_type: str = "model", version: str = None, skip_hf: bool = False, -): - """ - Upload a single local area H5 file to a subdirectory (states/ or districts/). +) -> int: + """Upload a single local area H5 file to a subdirectory + (states/ or districts/). - Uploads to both GCS and Hugging Face with the file placed in the specified - subdirectory. + Uploads to both GCS and Hugging Face with the file placed + in the specified subdirectory. Args: - skip_hf: If True, skip HuggingFace upload (for batched uploads later) + skip_hf: If True, skip HuggingFace upload (for batched + uploads later) + + Returns: + The GCS generation number of the uploaded blob. """ if version is None: version = metadata.version("policyengine-us-data") @@ -160,7 +219,9 @@ def upload_local_area_file( # Upload to GCS with subdirectory credentials, project_id = google.auth.default() - storage_client = storage.Client(credentials=credentials, project=project_id) + storage_client = storage.Client( + credentials=credentials, project=project_id + ) bucket = storage_client.bucket(gcs_bucket_name) blob_name = f"{subdirectory}/{file_path.name}" @@ -168,10 +229,15 @@ def upload_local_area_file( blob.upload_from_filename(file_path) blob.metadata = {"version": version} blob.patch() - logging.info(f"Uploaded {blob_name} to GCS bucket {gcs_bucket_name}.") + blob.reload() + generation = blob.generation + logging.info( + f"Uploaded {blob_name} to GCS bucket " + f"{gcs_bucket_name} (generation {generation})." + ) if skip_hf: - return + return generation # Upload to Hugging Face with subdirectory token = os.environ.get("HUGGING_FACE_TOKEN") @@ -182,12 +248,17 @@ def upload_local_area_file( repo_id=hf_repo_name, repo_type=hf_repo_type, token=token, - commit_message=f"Upload {subdirectory}/{file_path.name} for version {version}", + commit_message=( + f"Upload {subdirectory}/{file_path.name} " f"for version {version}" + ), ) logging.info( - f"Uploaded {subdirectory}/{file_path.name} to Hugging Face {hf_repo_name}." + f"Uploaded {subdirectory}/{file_path.name} to " + f"Hugging Face {hf_repo_name}." ) + return generation + def upload_local_area_batch_to_hf( files_with_subdirs: List[tuple], @@ -330,7 +401,9 @@ def upload_to_staging_hf( f"Uploaded batch {i // batch_size + 1}: {len(operations)} files to staging/" ) - logging.info(f"Total: uploaded {total_uploaded} files to staging/ in HuggingFace") + logging.info( + f"Total: uploaded {total_uploaded} files to staging/ in HuggingFace" + ) return total_uploaded diff --git a/policyengine_us_data/utils/gcs_version.py b/policyengine_us_data/utils/gcs_version.py new file mode 100644 index 000000000..c49aaed01 --- /dev/null +++ b/policyengine_us_data/utils/gcs_version.py @@ -0,0 +1,574 @@ +""" +GCS version registry for semver-based dataset versioning. + +Provides typed structures and functions for versioned uploads, +downloads, and rollbacks across GCS and Hugging Face. All +versions are tracked in a single registry file +(version_manifest.json) on both backends. +""" + +import json +import logging +import os +import tempfile +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Optional + +from google.api_core.exceptions import NotFound +from google.cloud import storage +from huggingface_hub import ( + HfApi, + CommitOperationAdd, + hf_hub_download, +) + +REGISTRY_BLOB = "version_manifest.json" + + +@dataclass +class HFVersionInfo: + """Hugging Face backend location for a version.""" + + repo: str + commit: str + + def to_dict(self) -> dict[str, str]: + return {"repo": self.repo, "commit": self.commit} + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "HFVersionInfo": + return cls(repo=data["repo"], commit=data["commit"]) + + +@dataclass +class GCSVersionInfo: + """GCS backend location for a version.""" + + bucket: str + generations: dict[str, int] + + def to_dict(self) -> dict[str, Any]: + return { + "bucket": self.bucket, + "generations": self.generations, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "GCSVersionInfo": + return cls( + bucket=data["bucket"], + generations=data["generations"], + ) + + +@dataclass +class VersionManifest: + """Single version entry tying semver to backend + identifiers. + + Consumers interact only with the semver version string. + HF commit SHAs and GCS generation numbers are internal + implementation details resolved by this manifest. + """ + + version: str + created_at: str + hf: Optional[HFVersionInfo] + gcs: GCSVersionInfo + special_operation: Optional[str] = None + roll_back_version: Optional[str] = None + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = { + "version": self.version, + "created_at": self.created_at, + "hf": self.hf.to_dict() if self.hf else None, + "gcs": self.gcs.to_dict(), + } + if self.special_operation is not None: + result["special_operation"] = self.special_operation + if self.roll_back_version is not None: + result["roll_back_version"] = self.roll_back_version + return result + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "VersionManifest": + hf_data = data.get("hf") + return cls( + version=data["version"], + created_at=data["created_at"], + hf=(HFVersionInfo.from_dict(hf_data) if hf_data else None), + gcs=GCSVersionInfo.from_dict(data["gcs"]), + special_operation=data.get("special_operation"), + roll_back_version=data.get("roll_back_version"), + ) + + +@dataclass +class VersionRegistry: + """Registry of all dataset versions. + + Contains a pointer to the current version and a list of + all version manifests (most recent first). + """ + + current: str = "" + versions: list[VersionManifest] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: + return { + "current": self.current, + "versions": [v.to_dict() for v in self.versions], + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "VersionRegistry": + return cls( + current=data["current"], + versions=[VersionManifest.from_dict(v) for v in data["versions"]], + ) + + def get_version(self, version: str) -> VersionManifest: + """Look up a specific version entry. + + Args: + version: Semver version string. + + Returns: + The matching VersionManifest. + + Raises: + ValueError: If the version is not in the registry. + """ + for v in self.versions: + if v.version == version: + return v + available = [v.version for v in self.versions[:10]] + raise ValueError( + f"Version '{version}' not found in registry. " + f"Available versions: {available}" + ) + + +def build_manifest( + bucket: storage.Bucket, + version: str, + blob_names: list[str], + hf_info: Optional[HFVersionInfo] = None, +) -> VersionManifest: + """Build a version manifest by reading generation numbers + from uploaded blobs. + + Args: + bucket: GCS bucket containing the uploaded blobs. + version: Semver version string. + blob_names: List of blob paths to include in the + manifest. + hf_info: Optional HF backend info to include. + + Returns: + A VersionManifest with generation numbers for each blob. + """ + generations: dict[str, int] = {} + for name in blob_names: + blob = bucket.get_blob(name) + if blob is None: + raise ValueError( + f"Blob '{name}' not found in bucket " + f"'{bucket.name}' after upload." + ) + generations[name] = blob.generation + + return VersionManifest( + version=version, + created_at=datetime.now(timezone.utc) + .isoformat() + .replace("+00:00", "Z"), + hf=hf_info, + gcs=GCSVersionInfo( + bucket=bucket.name, + generations=generations, + ), + ) + + +# -- Registry I/O ------------------------------------------------- + + +def _read_registry_from_gcs( + bucket: storage.Bucket, +) -> VersionRegistry: + """Read the version registry from GCS. + + Returns an empty registry if no registry exists yet. + """ + blob = bucket.blob(REGISTRY_BLOB) + try: + content = blob.download_as_text() + except NotFound: + return VersionRegistry() + return VersionRegistry.from_dict(json.loads(content)) + + +def _upload_registry_to_gcs( + bucket: storage.Bucket, + registry: VersionRegistry, +) -> None: + """Write the version registry to GCS.""" + data = json.dumps(registry.to_dict(), indent=2) + blob = bucket.blob(REGISTRY_BLOB) + blob.upload_from_string(data, content_type="application/json") + logging.info(f"Uploaded registry to GCS " f"(current={registry.current}).") + + +def _upload_registry_to_hf( + registry: VersionRegistry, + hf_repo_name: str, + hf_repo_type: str, +) -> None: + """Write the version registry to Hugging Face.""" + token = os.environ.get("HUGGING_FACE_TOKEN") + api = HfApi() + data = json.dumps(registry.to_dict(), indent=2) + + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as f: + f.write(data) + tmp_path = f.name + + try: + api.upload_file( + path_or_fileobj=tmp_path, + path_in_repo=REGISTRY_BLOB, + repo_id=hf_repo_name, + repo_type=hf_repo_type, + token=token, + commit_message=( + f"Update version registry " f"(current={registry.current})" + ), + ) + logging.info( + f"Uploaded {REGISTRY_BLOB} to " f"HF repo {hf_repo_name}." + ) + finally: + os.unlink(tmp_path) + + +def upload_manifest( + bucket: storage.Bucket, + manifest: VersionManifest, + hf_repo_name: Optional[str] = None, + hf_repo_type: str = "model", +) -> None: + """Append a version manifest to the registry and upload. + + Reads the existing registry from GCS (or starts fresh), + prepends the new manifest, updates the current pointer, + and writes the registry to GCS and optionally HF. + + Args: + bucket: GCS bucket to upload to. + manifest: The version manifest to add. + hf_repo_name: If provided, also upload to this + HF repo. + hf_repo_type: HF repository type. + """ + registry = _read_registry_from_gcs(bucket) + registry.versions.insert(0, manifest) + registry.current = manifest.version + + _upload_registry_to_gcs(bucket, registry) + + if hf_repo_name is not None: + _upload_registry_to_hf(registry, hf_repo_name, hf_repo_type) + + +# -- Query functions ----------------------------------------------- + + +def get_current_version( + bucket: storage.Bucket, +) -> Optional[str]: + """Get the current version of the bucket. + + Args: + bucket: GCS bucket to query. + + Returns: + The current semver version string, or None if no + registry exists. + """ + registry = _read_registry_from_gcs(bucket) + if not registry.current: + return None + return registry.current + + +def get_manifest( + bucket: storage.Bucket, + version: str, +) -> VersionManifest: + """Get the manifest for a specific version. + + Args: + bucket: GCS bucket to query. + version: Semver version string. + + Returns: + The deserialized VersionManifest. + + Raises: + ValueError: If the version is not in the registry. + """ + registry = _read_registry_from_gcs(bucket) + return registry.get_version(version) + + +def list_versions( + bucket: storage.Bucket, +) -> list[str]: + """List all available versions in the bucket. + + Args: + bucket: GCS bucket to query. + + Returns: + Sorted list of semver version strings. + """ + registry = _read_registry_from_gcs(bucket) + return sorted(v.version for v in registry.versions) + + +def download_versioned_file( + bucket: storage.Bucket, + file_path: str, + version: str, + local_path: str, +) -> str: + """Download a specific file at a specific version. + + Args: + bucket: GCS bucket to download from. + file_path: Path of the file within the bucket. + version: Semver version string. + local_path: Local path to save the file to. + + Returns: + The local path where the file was saved. + + Raises: + ValueError: If the version or file is not found. + """ + manifest = get_manifest(bucket, version) + + if file_path not in manifest.gcs.generations: + raise ValueError( + f"File '{file_path}' not found in manifest for " + f"version '{version}'. Available files: " + f"{list(manifest.gcs.generations.keys())[:10]}..." + ) + + generation = manifest.gcs.generations[file_path] + blob = bucket.blob(file_path, generation=generation) + + Path(local_path).parent.mkdir(parents=True, exist_ok=True) + blob.download_to_filename(local_path) + + logging.info( + f"Downloaded {file_path} at version {version} " + f"(generation {generation}) to {local_path}." + ) + return local_path + + +def rollback( + bucket: storage.Bucket, + target_version: str, + new_version: str, + hf_repo_name: str = "policyengine/policyengine-us-data", + hf_repo_type: str = "model", +) -> VersionManifest: + """Roll back by releasing a new version with old data. + + This treats rollback as a new release: data from + target_version is copied to the live paths (creating new + GCS generations), a new HF commit is created with the + old data, and a new manifest is published under + new_version with special_operation="roll-back". + + Args: + bucket: GCS bucket to roll back. + target_version: Semver version to roll back to. + new_version: New semver version to publish + (e.g., "1.73.0"). + hf_repo_name: HuggingFace repository name. + hf_repo_type: HuggingFace repository type. + + Returns: + The new VersionManifest for the rollback release. + + Raises: + ValueError: If target_version is not in the registry. + """ + old_manifest = get_manifest(bucket, target_version) + + # 1. Restore GCS files by copying old generations + # to live paths, then record new generations. + new_generations: dict[str, int] = {} + for file_path, generation in old_manifest.gcs.generations.items(): + source_blob = bucket.blob(file_path, generation=generation) + bucket.copy_blob(source_blob, bucket, file_path) + # Read back the new generation + restored_blob = bucket.get_blob(file_path) + new_generations[file_path] = restored_blob.generation + logging.info( + f"Restored {file_path}: generation " + f"{generation} -> {restored_blob.generation}." + ) + + # 2. Re-upload old data to HF as a new commit + hf_commit = None + if old_manifest.hf is not None: + token = os.environ.get("HUGGING_FACE_TOKEN") + api = HfApi() + + operations = [] + with tempfile.TemporaryDirectory() as tmpdir: + for file_path in old_manifest.gcs.generations.keys(): + local_path = os.path.join(tmpdir, file_path.replace("/", "_")) + hf_hub_download( + repo_id=old_manifest.hf.repo, + repo_type=hf_repo_type, + filename=file_path, + revision=old_manifest.hf.commit, + local_dir=tmpdir, + token=token, + ) + downloaded = os.path.join(tmpdir, file_path) + operations.append( + CommitOperationAdd( + path_in_repo=file_path, + path_or_fileobj=downloaded, + ) + ) + + commit_info = api.create_commit( + token=token, + repo_id=hf_repo_name, + operations=operations, + repo_type=hf_repo_type, + commit_message=( + f"Roll back to {target_version} " f"as {new_version}" + ), + ) + hf_commit = commit_info.oid + + # Tag the new commit + try: + api.create_tag( + token=token, + repo_id=hf_repo_name, + tag=new_version, + revision=hf_commit, + repo_type=hf_repo_type, + ) + except Exception as e: + if "already exists" in str(e) or "409" in str(e): + logging.warning( + f"Tag {new_version} already exists. " + f"Skipping tag creation." + ) + else: + raise + + # 3. Build and upload the new manifest + new_manifest = VersionManifest( + version=new_version, + created_at=datetime.now(timezone.utc) + .isoformat() + .replace("+00:00", "Z"), + hf=( + HFVersionInfo(repo=hf_repo_name, commit=hf_commit) + if hf_commit + else None + ), + gcs=GCSVersionInfo( + bucket=bucket.name, + generations=new_generations, + ), + special_operation="roll-back", + roll_back_version=target_version, + ) + upload_manifest( + bucket, + new_manifest, + hf_repo_name=hf_repo_name, + hf_repo_type=hf_repo_type, + ) + + logging.info( + f"Rolled back to {target_version} as new " + f"version {new_version}. " + f"Restored {len(new_generations)} files." + ) + return new_manifest + + +# -- Consumer API -------------------------------------------------- + + +_cached_registry: Optional[VersionRegistry] = None + + +def get_data_manifest( + hf_repo_name: str = "policyengine/policyengine-us-data", + hf_repo_type: str = "model", +) -> VersionRegistry: + """Get the full version registry from HF. + + Fetches version_manifest.json from the Hugging Face repo + and returns it as a VersionRegistry. The result is cached + in memory after the first call. + + Args: + hf_repo_name: HF repository name. + hf_repo_type: HF repository type. + + Returns: + The full VersionRegistry. + """ + global _cached_registry + if _cached_registry is not None: + return _cached_registry + + local_path = hf_hub_download( + repo_id=hf_repo_name, + repo_type=hf_repo_type, + filename=REGISTRY_BLOB, + ) + with open(local_path) as f: + data = json.load(f) + + _cached_registry = VersionRegistry.from_dict(data) + return _cached_registry + + +def get_data_version( + hf_repo_name: str = "policyengine/policyengine-us-data", + hf_repo_type: str = "model", +) -> str: + """Get the current deployed data version string. + + Convenience wrapper around get_data_manifest(). + + Args: + hf_repo_name: HF repository name. + hf_repo_type: HF repository type. + + Returns: + The current semver version string. + """ + return get_data_manifest(hf_repo_name, hf_repo_type).current From 9a6f843ef1a82745c773c3f91390360f260fe381 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 11 Mar 2026 18:56:51 +0100 Subject: [PATCH 2/8] Extract test fixtures into shared fixtures module Co-Authored-By: Claude Opus 4.6 --- .../tests/fixtures/__init__.py | 0 .../tests/fixtures/test_gcs_version.py | 81 ++++++++++++++++ .../tests/test_gcs_version.py | 96 ++++--------------- 3 files changed, 102 insertions(+), 75 deletions(-) create mode 100644 policyengine_us_data/tests/fixtures/__init__.py create mode 100644 policyengine_us_data/tests/fixtures/test_gcs_version.py diff --git a/policyengine_us_data/tests/fixtures/__init__.py b/policyengine_us_data/tests/fixtures/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/policyengine_us_data/tests/fixtures/test_gcs_version.py b/policyengine_us_data/tests/fixtures/test_gcs_version.py new file mode 100644 index 000000000..d41ae55f5 --- /dev/null +++ b/policyengine_us_data/tests/fixtures/test_gcs_version.py @@ -0,0 +1,81 @@ +"""Shared fixtures for GCS version registry tests.""" + +import json +from unittest.mock import MagicMock + +import pytest + +from policyengine_us_data.utils.gcs_version import ( + HFVersionInfo, + GCSVersionInfo, + VersionManifest, + VersionRegistry, +) + + +@pytest.fixture +def sample_generations() -> dict[str, int]: + return { + "enhanced_cps_2024.h5": 1710203948123456, + "cps_2024.h5": 1710203948234567, + "states/AL.h5": 1710203948345678, + } + + +@pytest.fixture +def sample_hf_info() -> HFVersionInfo: + return HFVersionInfo( + repo="policyengine/policyengine-us-data", + commit="abc123def456", + ) + + +@pytest.fixture +def sample_manifest( + sample_generations: dict[str, int], + sample_hf_info: HFVersionInfo, +) -> VersionManifest: + return VersionManifest( + version="1.72.3", + created_at="2026-03-10T14:30:00Z", + hf=sample_hf_info, + gcs=GCSVersionInfo( + bucket="policyengine-us-data", + generations=sample_generations, + ), + ) + + +@pytest.fixture +def sample_registry( + sample_manifest: VersionManifest, +) -> VersionRegistry: + """A registry with one version entry.""" + return VersionRegistry( + current="1.72.3", + versions=[sample_manifest], + ) + + +@pytest.fixture +def mock_bucket() -> MagicMock: + bucket = MagicMock() + bucket.name = "policyengine-us-data" + return bucket + + +def make_mock_blob(generation: int) -> MagicMock: + blob = MagicMock() + blob.generation = generation + return blob + + +def setup_bucket_with_registry( + bucket: MagicMock, + registry: VersionRegistry, +) -> None: + """Configure a mock bucket to serve a registry.""" + registry_json = json.dumps(registry.to_dict()) + blob = MagicMock() + blob.download_as_text.return_value = registry_json + bucket.blob.return_value = blob diff --git a/policyengine_us_data/tests/test_gcs_version.py b/policyengine_us_data/tests/test_gcs_version.py index 27d91094c..90c2d96a4 100644 --- a/policyengine_us_data/tests/test_gcs_version.py +++ b/policyengine_us_data/tests/test_gcs_version.py @@ -21,69 +21,15 @@ get_data_manifest, get_data_version, ) - -# -- Fixtures ------------------------------------------------------- - - -@pytest.fixture -def sample_generations(): - return { - "enhanced_cps_2024.h5": 1710203948123456, - "cps_2024.h5": 1710203948234567, - "states/AL.h5": 1710203948345678, - } - - -@pytest.fixture -def sample_hf_info(): - return HFVersionInfo( - repo="policyengine/policyengine-us-data", - commit="abc123def456", - ) - - -@pytest.fixture -def sample_manifest(sample_generations, sample_hf_info): - return VersionManifest( - version="1.72.3", - created_at="2026-03-10T14:30:00Z", - hf=sample_hf_info, - gcs=GCSVersionInfo( - bucket="policyengine-us-data", - generations=sample_generations, - ), - ) - - -@pytest.fixture -def sample_registry(sample_manifest): - """A registry with one version entry.""" - return VersionRegistry( - current="1.72.3", - versions=[sample_manifest], - ) - - -@pytest.fixture -def mock_bucket(): - bucket = MagicMock() - bucket.name = "policyengine-us-data" - return bucket - - -def _make_mock_blob(generation: int): - blob = MagicMock() - blob.generation = generation - return blob - - -def _setup_bucket_with_registry(bucket, registry): - """Configure a mock bucket to serve a registry.""" - registry_json = json.dumps(registry.to_dict()) - blob = MagicMock() - blob.download_as_text.return_value = registry_json - bucket.blob.return_value = blob - +from policyengine_us_data.tests.fixtures.test_gcs_version import ( + sample_generations, + sample_hf_info, + sample_manifest, + sample_registry, + mock_bucket, + make_mock_blob, + setup_bucket_with_registry, +) # -- VersionManifest serialization tests --------------------------- @@ -272,9 +218,9 @@ def test_structure(self, mock_bucket): "file_c.h5", ] mock_bucket.get_blob.side_effect = [ - _make_mock_blob(100), - _make_mock_blob(200), - _make_mock_blob(300), + make_mock_blob(100), + make_mock_blob(200), + make_mock_blob(300), ] result = build_manifest(mock_bucket, "1.72.3", blob_names) @@ -296,8 +242,8 @@ def test_with_subdirectories(self, mock_bucket): "districts/CA-01.h5", ] mock_bucket.get_blob.side_effect = [ - _make_mock_blob(111), - _make_mock_blob(222), + make_mock_blob(111), + make_mock_blob(222), ] result = build_manifest(mock_bucket, "1.72.3", blob_names) @@ -308,7 +254,7 @@ def test_with_subdirectories(self, mock_bucket): assert result.gcs.generations["districts/CA-01.h5"] == 222 def test_with_hf_info(self, mock_bucket, sample_hf_info): - mock_bucket.get_blob.return_value = _make_mock_blob(999) + mock_bucket.get_blob.return_value = make_mock_blob(999) result = build_manifest( mock_bucket, @@ -458,7 +404,7 @@ def test_skips_hf_when_no_repo(self, mock_bucket, sample_manifest): class TestGetCurrentVersion: def test_returns_version(self, mock_bucket, sample_registry): - _setup_bucket_with_registry(mock_bucket, sample_registry) + setup_bucket_with_registry(mock_bucket, sample_registry) result = get_current_version(mock_bucket) @@ -480,7 +426,7 @@ def test_no_registry_returns_none(self, mock_bucket): class TestGetManifest: def test_specific_version(self, mock_bucket, sample_registry): - _setup_bucket_with_registry(mock_bucket, sample_registry) + setup_bucket_with_registry(mock_bucket, sample_registry) result = get_manifest(mock_bucket, "1.72.3") @@ -492,7 +438,7 @@ def test_specific_version(self, mock_bucket, sample_registry): ) def test_nonexistent_version(self, mock_bucket, sample_registry): - _setup_bucket_with_registry(mock_bucket, sample_registry) + setup_bucket_with_registry(mock_bucket, sample_registry) with pytest.raises(ValueError, match="not found"): get_manifest(mock_bucket, "9.9.9") @@ -530,7 +476,7 @@ def test_returns_sorted(self, mock_bucket): gcs=GCSVersionInfo(bucket="b", generations={"f.h5": 3}), ) registry = VersionRegistry(current="1.72.3", versions=[v2, v3, v1]) - _setup_bucket_with_registry(mock_bucket, registry) + setup_bucket_with_registry(mock_bucket, registry) result = list_versions(mock_bucket) @@ -538,7 +484,7 @@ def test_returns_sorted(self, mock_bucket): def test_empty(self, mock_bucket): registry = VersionRegistry() - _setup_bucket_with_registry(mock_bucket, registry) + setup_bucket_with_registry(mock_bucket, registry) result = list_versions(mock_bucket) @@ -595,7 +541,7 @@ def test_file_not_in_manifest( registry = VersionRegistry( current="1.72.3", versions=[sample_manifest] ) - _setup_bucket_with_registry(mock_bucket, registry) + setup_bucket_with_registry(mock_bucket, registry) with pytest.raises(ValueError, match="not found"): download_versioned_file( From 017c030dbcd38f2e5d5228b2736627be83332f58 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 11 Mar 2026 20:18:14 +0100 Subject: [PATCH 3/8] Rename gcs_version to version_manifest and refactor Rename the module to reflect that both GCS and HF are equal backends, not just GCS. Add a constants block for all backend config, remove optional HF parameters from public API, and extract step-based procedural blocks into named helper functions. Co-Authored-By: Claude Opus 4.6 --- policyengine_us_data/__init__.py | 2 +- ...cs_version.py => test_version_manifest.py} | 4 +- ...cs_version.py => test_version_manifest.py} | 272 +++++++---- policyengine_us_data/utils/data_upload.py | 28 +- .../{gcs_version.py => version_manifest.py} | 425 +++++++++--------- 5 files changed, 398 insertions(+), 333 deletions(-) rename policyengine_us_data/tests/fixtures/{test_gcs_version.py => test_version_manifest.py} (94%) rename policyengine_us_data/tests/{test_gcs_version.py => test_version_manifest.py} (78%) rename policyengine_us_data/utils/{gcs_version.py => version_manifest.py} (62%) diff --git a/policyengine_us_data/__init__.py b/policyengine_us_data/__init__.py index e2707f030..493ff1b88 100644 --- a/policyengine_us_data/__init__.py +++ b/policyengine_us_data/__init__.py @@ -1,3 +1,3 @@ from .datasets import * from .geography import ZIP_CODE_DATASET -from .utils.gcs_version import get_data_version, get_data_manifest +from .utils.version_manifest import get_data_version, get_data_manifest diff --git a/policyengine_us_data/tests/fixtures/test_gcs_version.py b/policyengine_us_data/tests/fixtures/test_version_manifest.py similarity index 94% rename from policyengine_us_data/tests/fixtures/test_gcs_version.py rename to policyengine_us_data/tests/fixtures/test_version_manifest.py index d41ae55f5..a6158dc76 100644 --- a/policyengine_us_data/tests/fixtures/test_gcs_version.py +++ b/policyengine_us_data/tests/fixtures/test_version_manifest.py @@ -1,11 +1,11 @@ -"""Shared fixtures for GCS version registry tests.""" +"""Shared fixtures for version manifest tests.""" import json from unittest.mock import MagicMock import pytest -from policyengine_us_data.utils.gcs_version import ( +from policyengine_us_data.utils.version_manifest import ( HFVersionInfo, GCSVersionInfo, VersionManifest, diff --git a/policyengine_us_data/tests/test_gcs_version.py b/policyengine_us_data/tests/test_version_manifest.py similarity index 78% rename from policyengine_us_data/tests/test_gcs_version.py rename to policyengine_us_data/tests/test_version_manifest.py index 90c2d96a4..c6c6b5058 100644 --- a/policyengine_us_data/tests/test_gcs_version.py +++ b/policyengine_us_data/tests/test_version_manifest.py @@ -1,4 +1,4 @@ -"""Tests for GCS version registry system.""" +"""Tests for version manifest registry system.""" import json from unittest.mock import MagicMock, patch, call @@ -6,7 +6,7 @@ import pytest from google.api_core.exceptions import NotFound -from policyengine_us_data.utils.gcs_version import ( +from policyengine_us_data.utils.version_manifest import ( HFVersionInfo, GCSVersionInfo, VersionManifest, @@ -21,7 +21,7 @@ get_data_manifest, get_data_version, ) -from policyengine_us_data.tests.fixtures.test_gcs_version import ( +from policyengine_us_data.tests.fixtures.test_version_manifest import ( sample_generations, sample_hf_info, sample_manifest, @@ -31,6 +31,9 @@ setup_bucket_with_registry, ) +_MOD = "policyengine_us_data.utils.version_manifest" + + # -- VersionManifest serialization tests --------------------------- @@ -53,13 +56,13 @@ def test_from_dict(self, sample_manifest): "version": "1.72.3", "created_at": "2026-03-10T14:30:00Z", "hf": { - "repo": "policyengine/policyengine-us-data", + "repo": ("policyengine/policyengine-us-data"), "commit": "abc123def456", }, "gcs": { "bucket": "policyengine-us-data", "generations": { - "enhanced_cps_2024.h5": 1710203948123456, + "enhanced_cps_2024.h5": (1710203948123456), "cps_2024.h5": 1710203948234567, "states/AL.h5": 1710203948345678, }, @@ -79,13 +82,11 @@ def test_roundtrip(self, sample_manifest): roundtripped = VersionManifest.from_dict(sample_manifest.to_dict()) assert roundtripped.version == sample_manifest.version - assert roundtripped.created_at == (sample_manifest.created_at) + assert roundtripped.created_at == sample_manifest.created_at assert roundtripped.hf.repo == sample_manifest.hf.repo - assert roundtripped.hf.commit == (sample_manifest.hf.commit) - assert roundtripped.gcs.bucket == (sample_manifest.gcs.bucket) - assert roundtripped.gcs.generations == ( - sample_manifest.gcs.generations - ) + assert roundtripped.hf.commit == sample_manifest.hf.commit + assert roundtripped.gcs.bucket == sample_manifest.gcs.bucket + assert roundtripped.gcs.generations == sample_manifest.gcs.generations def test_without_hf(self, sample_generations): manifest = VersionManifest( @@ -145,7 +146,9 @@ def test_special_operation_roundtrip( assert roundtripped.special_operation == "roll-back" assert roundtripped.roll_back_version == "1.70.1" - def test_regular_manifest_has_no_special_operation(self): + def test_regular_manifest_has_no_special_operation( + self, + ): data = { "version": "1.72.3", "created_at": "2026-03-10T14:30:00Z", @@ -185,7 +188,7 @@ def test_from_dict(self, sample_manifest): def test_roundtrip(self, sample_registry): roundtripped = VersionRegistry.from_dict(sample_registry.to_dict()) - assert roundtripped.current == (sample_registry.current) + assert roundtripped.current == sample_registry.current assert len(roundtripped.versions) == len(sample_registry.versions) assert roundtripped.versions[0].version == "1.72.3" @@ -211,7 +214,9 @@ def test_empty_registry(self): class TestBuildManifest: - def test_structure(self, mock_bucket): + @patch(f"{_MOD}._get_gcs_bucket") + def test_structure(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket blob_names = [ "file_a.h5", "file_b.h5", @@ -223,7 +228,7 @@ def test_structure(self, mock_bucket): make_mock_blob(300), ] - result = build_manifest(mock_bucket, "1.72.3", blob_names) + result = build_manifest("1.72.3", blob_names) assert isinstance(result, VersionManifest) assert result.version == "1.72.3" @@ -236,7 +241,9 @@ def test_structure(self, mock_bucket): assert result.gcs.bucket == "policyengine-us-data" assert result.hf is None - def test_with_subdirectories(self, mock_bucket): + @patch(f"{_MOD}._get_gcs_bucket") + def test_with_subdirectories(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket blob_names = [ "states/AL.h5", "districts/CA-01.h5", @@ -246,18 +253,24 @@ def test_with_subdirectories(self, mock_bucket): make_mock_blob(222), ] - result = build_manifest(mock_bucket, "1.72.3", blob_names) + result = build_manifest("1.72.3", blob_names) assert "states/AL.h5" in result.gcs.generations assert "districts/CA-01.h5" in result.gcs.generations assert result.gcs.generations["states/AL.h5"] == 111 assert result.gcs.generations["districts/CA-01.h5"] == 222 - def test_with_hf_info(self, mock_bucket, sample_hf_info): + @patch(f"{_MOD}._get_gcs_bucket") + def test_with_hf_info( + self, + mock_get_bucket, + mock_bucket, + sample_hf_info, + ): + mock_get_bucket.return_value = mock_bucket mock_bucket.get_blob.return_value = make_mock_blob(999) result = build_manifest( - mock_bucket, "1.72.3", ["file.h5"], hf_info=sample_hf_info, @@ -267,11 +280,13 @@ def test_with_hf_info(self, mock_bucket, sample_hf_info): assert result.hf.commit == "abc123def456" assert result.hf.repo == ("policyengine/policyengine-us-data") - def test_missing_blob_raises(self, mock_bucket): + @patch(f"{_MOD}._get_gcs_bucket") + def test_missing_blob_raises(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket mock_bucket.get_blob.return_value = None with pytest.raises(ValueError, match="not found"): - build_manifest(mock_bucket, "1.72.3", ["missing.h5"]) + build_manifest("1.72.3", ["missing.h5"]) # -- upload_manifest tests ----------------------------------------- @@ -280,10 +295,6 @@ def test_missing_blob_raises(self, mock_bucket): class TestUploadManifest: def _setup_empty_registry(self, bucket): """Mock bucket with no existing registry.""" - blob = MagicMock() - blob.download_as_text.side_effect = NotFound("Not found") - # First call reads existing registry (NotFound), - # subsequent calls are for writing written = {} def mock_blob(name): @@ -301,10 +312,19 @@ def mock_blob(name): bucket.blob.side_effect = mock_blob return written - def test_writes_registry_to_gcs(self, mock_bucket, sample_manifest): + @patch(f"{_MOD}._upload_registry_to_hf") + @patch(f"{_MOD}._get_gcs_bucket") + def test_writes_registry_to_gcs( + self, + mock_get_bucket, + mock_hf, + mock_bucket, + sample_manifest, + ): + mock_get_bucket.return_value = mock_bucket written = self._setup_empty_registry(mock_bucket) - upload_manifest(mock_bucket, sample_manifest) + upload_manifest(sample_manifest) assert "version_manifest.json" in written blob = written["version_manifest.json"] @@ -315,10 +335,19 @@ def test_writes_registry_to_gcs(self, mock_bucket, sample_manifest): assert len(registry_data["versions"]) == 1 assert registry_data["versions"][0]["version"] == "1.72.3" - def test_includes_hf_commit(self, mock_bucket, sample_manifest): + @patch(f"{_MOD}._upload_registry_to_hf") + @patch(f"{_MOD}._get_gcs_bucket") + def test_includes_hf_commit( + self, + mock_get_bucket, + mock_hf, + mock_bucket, + sample_manifest, + ): + mock_get_bucket.return_value = mock_bucket written = self._setup_empty_registry(mock_bucket) - upload_manifest(mock_bucket, sample_manifest) + upload_manifest(sample_manifest) blob = written["version_manifest.json"] written_json = blob.upload_from_string.call_args[0][0] @@ -326,8 +355,16 @@ def test_includes_hf_commit(self, mock_bucket, sample_manifest): assert registry_data["versions"][0]["hf"]["commit"] == "abc123def456" - def test_appends_to_existing_registry(self, mock_bucket, sample_manifest): - # Pre-populate with an older version + @patch(f"{_MOD}._upload_registry_to_hf") + @patch(f"{_MOD}._get_gcs_bucket") + def test_appends_to_existing_registry( + self, + mock_get_bucket, + mock_hf, + mock_bucket, + sample_manifest, + ): + mock_get_bucket.return_value = mock_bucket older = VersionManifest( version="1.72.2", created_at="2026-03-09T10:00:00Z", @@ -350,7 +387,7 @@ def mock_blob(name): mock_bucket.blob.side_effect = mock_blob - upload_manifest(mock_bucket, sample_manifest) + upload_manifest(sample_manifest) blob = written["version_manifest.json"] written_json = blob.upload_from_string.call_args[0][0] @@ -358,65 +395,65 @@ def mock_blob(name): assert registry_data["current"] == "1.72.3" assert len(registry_data["versions"]) == 2 - # Most recent first assert registry_data["versions"][0]["version"] == "1.72.3" assert registry_data["versions"][1]["version"] == "1.72.2" - @patch("policyengine_us_data.utils.gcs_version.os") - @patch("policyengine_us_data.utils.gcs_version.HfApi") - def test_uploads_to_hf_when_repo_provided( + @patch(f"{_MOD}.os") + @patch(f"{_MOD}.HfApi") + @patch(f"{_MOD}._get_gcs_bucket") + def test_always_uploads_to_hf( self, + mock_get_bucket, mock_hf_api_cls, mock_os, mock_bucket, sample_manifest, ): + mock_get_bucket.return_value = mock_bucket mock_os.environ.get.return_value = "fake_token" mock_os.unlink = MagicMock() mock_api = MagicMock() mock_hf_api_cls.return_value = mock_api - # Mock GCS read (empty registry) blob = MagicMock() blob.download_as_text.side_effect = NotFound("Not found") mock_bucket.blob.return_value = blob - upload_manifest( - mock_bucket, - sample_manifest, - hf_repo_name=("policyengine/policyengine-us-data"), - ) + upload_manifest(sample_manifest) mock_api.upload_file.assert_called_once() call_kwargs = mock_api.upload_file.call_args.kwargs assert call_kwargs["path_in_repo"] == ("version_manifest.json") assert call_kwargs["repo_id"] == ("policyengine/policyengine-us-data") - def test_skips_hf_when_no_repo(self, mock_bucket, sample_manifest): - self._setup_empty_registry(mock_bucket) - - # No hf_repo_name — should not raise or call HF - upload_manifest(mock_bucket, sample_manifest) - # -- get_current_version tests ------------------------------------- class TestGetCurrentVersion: - def test_returns_version(self, mock_bucket, sample_registry): + @patch(f"{_MOD}._get_gcs_bucket") + def test_returns_version( + self, + mock_get_bucket, + mock_bucket, + sample_registry, + ): + mock_get_bucket.return_value = mock_bucket setup_bucket_with_registry(mock_bucket, sample_registry) - result = get_current_version(mock_bucket) + result = get_current_version() assert result == "1.72.3" mock_bucket.blob.assert_called_with("version_manifest.json") - def test_no_registry_returns_none(self, mock_bucket): + @patch(f"{_MOD}._get_gcs_bucket") + def test_no_registry_returns_none(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket blob = MagicMock() blob.download_as_text.side_effect = NotFound("Not found") mock_bucket.blob.return_value = blob - result = get_current_version(mock_bucket) + result = get_current_version() assert result is None @@ -425,10 +462,17 @@ def test_no_registry_returns_none(self, mock_bucket): class TestGetManifest: - def test_specific_version(self, mock_bucket, sample_registry): + @patch(f"{_MOD}._get_gcs_bucket") + def test_specific_version( + self, + mock_get_bucket, + mock_bucket, + sample_registry, + ): + mock_get_bucket.return_value = mock_bucket setup_bucket_with_registry(mock_bucket, sample_registry) - result = get_manifest(mock_bucket, "1.72.3") + result = get_manifest("1.72.3") assert isinstance(result, VersionManifest) assert result.version == "1.72.3" @@ -437,26 +481,37 @@ def test_specific_version(self, mock_bucket, sample_registry): result.gcs.generations["enhanced_cps_2024.h5"] == 1710203948123456 ) - def test_nonexistent_version(self, mock_bucket, sample_registry): + @patch(f"{_MOD}._get_gcs_bucket") + def test_nonexistent_version( + self, + mock_get_bucket, + mock_bucket, + sample_registry, + ): + mock_get_bucket.return_value = mock_bucket setup_bucket_with_registry(mock_bucket, sample_registry) with pytest.raises(ValueError, match="not found"): - get_manifest(mock_bucket, "9.9.9") + get_manifest("9.9.9") - def test_no_registry_raises(self, mock_bucket): + @patch(f"{_MOD}._get_gcs_bucket") + def test_no_registry_raises(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket blob = MagicMock() blob.download_as_text.side_effect = NotFound("Not found") mock_bucket.blob.return_value = blob with pytest.raises(ValueError, match="not found"): - get_manifest(mock_bucket, "1.72.3") + get_manifest("1.72.3") # -- list_versions tests ------------------------------------------- class TestListVersions: - def test_returns_sorted(self, mock_bucket): + @patch(f"{_MOD}._get_gcs_bucket") + def test_returns_sorted(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket v1 = VersionManifest( version="1.72.1", created_at="t1", @@ -478,15 +533,21 @@ def test_returns_sorted(self, mock_bucket): registry = VersionRegistry(current="1.72.3", versions=[v2, v3, v1]) setup_bucket_with_registry(mock_bucket, registry) - result = list_versions(mock_bucket) + result = list_versions() - assert result == ["1.72.1", "1.72.2", "1.72.3"] + assert result == [ + "1.72.1", + "1.72.2", + "1.72.3", + ] - def test_empty(self, mock_bucket): + @patch(f"{_MOD}._get_gcs_bucket") + def test_empty(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket registry = VersionRegistry() setup_bucket_with_registry(mock_bucket, registry) - result = list_versions(mock_bucket) + result = list_versions() assert result == [] @@ -495,11 +556,18 @@ def test_empty(self, mock_bucket): class TestDownloadVersionedFile: + @patch(f"{_MOD}._get_gcs_bucket") def test_downloads_correct_generation( - self, mock_bucket, sample_manifest, tmp_path + self, + mock_get_bucket, + mock_bucket, + sample_manifest, + tmp_path, ): + mock_get_bucket.return_value = mock_bucket registry = VersionRegistry( - current="1.72.3", versions=[sample_manifest] + current="1.72.3", + versions=[sample_manifest], ) registry_json = json.dumps(registry.to_dict()) @@ -517,7 +585,6 @@ def mock_blob(name, generation=None): local_path = str(tmp_path / "AL.h5") download_versioned_file( - mock_bucket, "states/AL.h5", "1.72.3", local_path, @@ -535,17 +602,23 @@ def mock_blob(name, generation=None): ] assert len(gen_call) == 1 + @patch(f"{_MOD}._get_gcs_bucket") def test_file_not_in_manifest( - self, mock_bucket, sample_manifest, tmp_path + self, + mock_get_bucket, + mock_bucket, + sample_manifest, + tmp_path, ): + mock_get_bucket.return_value = mock_bucket registry = VersionRegistry( - current="1.72.3", versions=[sample_manifest] + current="1.72.3", + versions=[sample_manifest], ) setup_bucket_with_registry(mock_bucket, registry) with pytest.raises(ValueError, match="not found"): download_versioned_file( - mock_bucket, "nonexistent.h5", "1.72.3", str(tmp_path / "out.h5"), @@ -556,12 +629,14 @@ def test_file_not_in_manifest( class TestRollback: - @patch("policyengine_us_data.utils.gcs_version." "CommitOperationAdd") - @patch("policyengine_us_data.utils.gcs_version." "hf_hub_download") - @patch("policyengine_us_data.utils.gcs_version.HfApi") - @patch("policyengine_us_data.utils.gcs_version.os") + @patch(f"{_MOD}.CommitOperationAdd") + @patch(f"{_MOD}.hf_hub_download") + @patch(f"{_MOD}.HfApi") + @patch(f"{_MOD}.os") + @patch(f"{_MOD}._get_gcs_bucket") def test_creates_new_version_with_old_data( self, + mock_get_bucket, mock_os, mock_hf_api_cls, mock_hf_download, @@ -569,19 +644,17 @@ def test_creates_new_version_with_old_data( mock_bucket, sample_manifest, ): + mock_get_bucket.return_value = mock_bucket mock_os.environ.get.return_value = "fake_token" mock_os.path.join = lambda *args: "/".join(args) mock_os.unlink = MagicMock() - # Setup HF mock mock_api = MagicMock() mock_hf_api_cls.return_value = mock_api commit_info = MagicMock() commit_info.oid = "new_commit_sha" mock_api.create_commit.return_value = commit_info - # Setup bucket: get_manifest reads registry, - # upload_manifest reads then writes registry registry = VersionRegistry( current="1.72.3", versions=[sample_manifest], @@ -603,7 +676,6 @@ def mock_blob(name, generation=None): mock_bucket.blob.side_effect = mock_blob - # get_blob returns blobs with new generations new_gen_counter = iter([50001, 50002, 50003]) def mock_get_blob(name): @@ -614,7 +686,6 @@ def mock_get_blob(name): mock_bucket.get_blob.side_effect = mock_get_blob result = rollback( - mock_bucket, target_version="1.72.3", new_version="1.73.0", ) @@ -624,10 +695,8 @@ def mock_get_blob(name): assert result.special_operation == "roll-back" assert result.roll_back_version == "1.72.3" - # GCS files were copied assert mock_bucket.copy_blob.call_count == 3 - # Registry was written with both versions blob = written["version_manifest.json"] written_json = blob.upload_from_string.call_args[0][0] registry_data = json.loads(written_json) @@ -637,21 +706,21 @@ def mock_get_blob(name): assert registry_data["versions"][0]["version"] == "1.73.0" assert registry_data["versions"][0]["special_operation"] == "roll-back" - # HF commit was created and tagged mock_api.create_commit.assert_called_once() commit_msg = mock_api.create_commit.call_args.kwargs["commit_message"] assert "1.72.3" in commit_msg assert "1.73.0" in commit_msg mock_api.create_tag.assert_called_once() - def test_nonexistent_version(self, mock_bucket): + @patch(f"{_MOD}._get_gcs_bucket") + def test_nonexistent_version(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket blob = MagicMock() blob.download_as_text.side_effect = NotFound("Not found") mock_bucket.blob.return_value = blob with pytest.raises(ValueError, match="not found"): rollback( - mock_bucket, target_version="9.9.9", new_version="9.10.0", ) @@ -662,7 +731,7 @@ def test_nonexistent_version(self, mock_bucket): class TestUploadFilesToGcsReturnsGenerations: @patch("policyengine_us_data.utils.data_upload." "google.auth") - @patch("policyengine_us_data.utils.data_upload.storage") + @patch("policyengine_us_data.utils.data_upload." "storage") def test_returns_generations(self, mock_storage, mock_auth, tmp_path): from policyengine_us_data.utils.data_upload import ( upload_files_to_gcs, @@ -704,23 +773,27 @@ def test_returns_generations(self, mock_storage, mock_auth, tmp_path): class TestEndToEndUploadCreatesRegistry: + @patch(f"{_MOD}._upload_registry_to_hf") + @patch(f"{_MOD}._get_gcs_bucket") @patch("policyengine_us_data.utils.data_upload." "google.auth") - @patch("policyengine_us_data.utils.data_upload.storage") + @patch("policyengine_us_data.utils.data_upload." "storage") @patch("policyengine_us_data.utils.data_upload.HfApi") @patch("policyengine_us_data.utils.data_upload.os") def test_creates_registry( self, mock_os, mock_hf_api_cls, - mock_storage, - mock_auth, + mock_du_storage, + mock_du_auth, + mock_get_bucket, + mock_hf_upload, tmp_path, ): from policyengine_us_data.utils.data_upload import ( upload_data_files, ) - mock_auth.default.return_value = ( + mock_du_auth.default.return_value = ( MagicMock(), "project", ) @@ -733,10 +806,11 @@ def test_creates_registry( mock_api.create_commit.return_value = commit_info mock_client = MagicMock() - mock_storage.Client.return_value = mock_client + mock_du_storage.Client.return_value = mock_client mock_bucket = MagicMock() mock_bucket.name = "policyengine-us-data" mock_client.bucket.return_value = mock_bucket + mock_get_bucket.return_value = mock_bucket blob_data = MagicMock() blob_data.generation = 55555 @@ -778,29 +852,29 @@ def mock_blob(name): class TestGetDataManifest: def setup_method(self): - import policyengine_us_data.utils.gcs_version as mod + import policyengine_us_data.utils.version_manifest as mod mod._cached_registry = None def teardown_method(self): - import policyengine_us_data.utils.gcs_version as mod + import policyengine_us_data.utils.version_manifest as mod mod._cached_registry = None - @patch("policyengine_us_data.utils.gcs_version." "hf_hub_download") + @patch(f"{_MOD}.hf_hub_download") def test_returns_registry(self, mock_download, tmp_path): registry_data = { "current": "1.72.3", "versions": [ { "version": "1.72.3", - "created_at": "2026-03-10T14:30:00Z", + "created_at": ("2026-03-10T14:30:00Z"), "hf": { "repo": ("policyengine/" "policyengine-us-data"), "commit": "abc123", }, "gcs": { - "bucket": "policyengine-us-data", + "bucket": ("policyengine-us-data"), "generations": {"file.h5": 12345}, }, }, @@ -817,19 +891,19 @@ def test_returns_registry(self, mock_download, tmp_path): assert len(result.versions) == 1 assert result.versions[0].hf.commit == "abc123" mock_download.assert_called_once_with( - repo_id="policyengine/policyengine-us-data", + repo_id=("policyengine/policyengine-us-data"), repo_type="model", filename="version_manifest.json", ) - @patch("policyengine_us_data.utils.gcs_version." "hf_hub_download") + @patch(f"{_MOD}.hf_hub_download") def test_caches_result(self, mock_download, tmp_path): registry_data = { "current": "1.72.3", "versions": [ { "version": "1.72.3", - "created_at": "2026-03-10T14:30:00Z", + "created_at": ("2026-03-10T14:30:00Z"), "hf": None, "gcs": { "bucket": "b", @@ -848,14 +922,14 @@ def test_caches_result(self, mock_download, tmp_path): assert first is second assert mock_download.call_count == 1 - @patch("policyengine_us_data.utils.gcs_version." "hf_hub_download") + @patch(f"{_MOD}.hf_hub_download") def test_get_data_version(self, mock_download, tmp_path): registry_data = { "current": "1.72.3", "versions": [ { "version": "1.72.3", - "created_at": "2026-03-10T14:30:00Z", + "created_at": ("2026-03-10T14:30:00Z"), "hf": None, "gcs": { "bucket": "b", diff --git a/policyengine_us_data/utils/data_upload.py b/policyengine_us_data/utils/data_upload.py index 6ff7c233c..93a10c639 100644 --- a/policyengine_us_data/utils/data_upload.py +++ b/policyengine_us_data/utils/data_upload.py @@ -35,13 +35,15 @@ def upload_data_files( hf_repo_type: str = "model", version: Optional[str] = None, ) -> None: - from policyengine_us_data.utils.gcs_version import ( + from policyengine_us_data.utils.version_manifest import ( + GCS_BUCKET_NAME, + HF_REPO_NAME, GCSVersionInfo, HFVersionInfo, VersionManifest, upload_manifest, + _utc_now_iso, ) - from datetime import datetime, timezone if version is None: version = metadata.version("policyengine-us-data") @@ -59,30 +61,16 @@ def upload_data_files( gcs_bucket_name=gcs_bucket_name, ) - # Build and upload version manifest - credentials, project_id = google.auth.default() - storage_client = storage.Client( - credentials=credentials, project=project_id - ) - bucket = storage_client.bucket(gcs_bucket_name) - manifest = VersionManifest( version=version, - created_at=datetime.now(timezone.utc) - .isoformat() - .replace("+00:00", "Z"), - hf=HFVersionInfo(repo=hf_repo_name, commit=hf_commit), + created_at=_utc_now_iso(), + hf=HFVersionInfo(repo=HF_REPO_NAME, commit=hf_commit), gcs=GCSVersionInfo( - bucket=gcs_bucket_name, + bucket=GCS_BUCKET_NAME, generations=generations, ), ) - upload_manifest( - bucket, - manifest, - hf_repo_name=hf_repo_name, - hf_repo_type=hf_repo_type, - ) + upload_manifest(manifest) logging.info(f"Created version manifest for {version}.") diff --git a/policyengine_us_data/utils/gcs_version.py b/policyengine_us_data/utils/version_manifest.py similarity index 62% rename from policyengine_us_data/utils/gcs_version.py rename to policyengine_us_data/utils/version_manifest.py index c49aaed01..c7d6348a7 100644 --- a/policyengine_us_data/utils/gcs_version.py +++ b/policyengine_us_data/utils/version_manifest.py @@ -1,5 +1,5 @@ """ -GCS version registry for semver-based dataset versioning. +Version registry for semver-based dataset versioning. Provides typed structures and functions for versioned uploads, downloads, and rollbacks across GCS and Hugging Face. All @@ -16,6 +16,7 @@ from pathlib import Path from typing import Any, Optional +import google.auth from google.api_core.exceptions import NotFound from google.cloud import storage from huggingface_hub import ( @@ -24,7 +25,15 @@ hf_hub_download, ) +# -- Configuration ------------------------------------------------- + REGISTRY_BLOB = "version_manifest.json" +GCS_BUCKET_NAME = "policyengine-us-data" +HF_REPO_NAME = "policyengine/policyengine-us-data" +HF_REPO_TYPE = "model" + + +# -- Types --------------------------------------------------------- @dataclass @@ -140,7 +149,8 @@ def get_version(self, version: str) -> VersionManifest: The matching VersionManifest. Raises: - ValueError: If the version is not in the registry. + ValueError: If the version is not in the + registry. """ for v in self.versions: if v.version == version: @@ -152,49 +162,19 @@ def get_version(self, version: str) -> VersionManifest: ) -def build_manifest( - bucket: storage.Bucket, - version: str, - blob_names: list[str], - hf_info: Optional[HFVersionInfo] = None, -) -> VersionManifest: - """Build a version manifest by reading generation numbers - from uploaded blobs. +# -- Internal helpers ---------------------------------------------- - Args: - bucket: GCS bucket containing the uploaded blobs. - version: Semver version string. - blob_names: List of blob paths to include in the - manifest. - hf_info: Optional HF backend info to include. - Returns: - A VersionManifest with generation numbers for each blob. - """ - generations: dict[str, int] = {} - for name in blob_names: - blob = bucket.get_blob(name) - if blob is None: - raise ValueError( - f"Blob '{name}' not found in bucket " - f"'{bucket.name}' after upload." - ) - generations[name] = blob.generation - - return VersionManifest( - version=version, - created_at=datetime.now(timezone.utc) - .isoformat() - .replace("+00:00", "Z"), - hf=hf_info, - gcs=GCSVersionInfo( - bucket=bucket.name, - generations=generations, - ), - ) +def _utc_now_iso() -> str: + """Return the current UTC time as an ISO 8601 string.""" + return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") -# -- Registry I/O ------------------------------------------------- +def _get_gcs_bucket() -> storage.Bucket: + """Return an authenticated GCS bucket handle.""" + credentials, project_id = google.auth.default() + client = storage.Client(credentials=credentials, project=project_id) + return client.bucket(GCS_BUCKET_NAME) def _read_registry_from_gcs( @@ -220,13 +200,11 @@ def _upload_registry_to_gcs( data = json.dumps(registry.to_dict(), indent=2) blob = bucket.blob(REGISTRY_BLOB) blob.upload_from_string(data, content_type="application/json") - logging.info(f"Uploaded registry to GCS " f"(current={registry.current}).") + logging.info("Uploaded registry to GCS " f"(current={registry.current}).") def _upload_registry_to_hf( registry: VersionRegistry, - hf_repo_name: str, - hf_repo_type: str, ) -> None: """Write the version registry to Hugging Face.""" token = os.environ.get("HUGGING_FACE_TOKEN") @@ -243,78 +221,195 @@ def _upload_registry_to_hf( api.upload_file( path_or_fileobj=tmp_path, path_in_repo=REGISTRY_BLOB, - repo_id=hf_repo_name, - repo_type=hf_repo_type, + repo_id=HF_REPO_NAME, + repo_type=HF_REPO_TYPE, token=token, commit_message=( - f"Update version registry " f"(current={registry.current})" + "Update version registry " f"(current={registry.current})" ), ) logging.info( - f"Uploaded {REGISTRY_BLOB} to " f"HF repo {hf_repo_name}." + f"Uploaded {REGISTRY_BLOB} to " f"HF repo {HF_REPO_NAME}." ) finally: os.unlink(tmp_path) -def upload_manifest( +def _restore_gcs_generations( bucket: storage.Bucket, + old_generations: dict[str, int], +) -> dict[str, int]: + """Copy old GCS generation blobs to live paths. + + Args: + bucket: GCS bucket containing the blobs. + old_generations: Map of blob path to old generation + number. + + Returns: + Map of blob path to new generation number. + """ + new_generations: dict[str, int] = {} + for file_path, generation in old_generations.items(): + source_blob = bucket.blob(file_path, generation=generation) + bucket.copy_blob(source_blob, bucket, file_path) + restored_blob = bucket.get_blob(file_path) + new_generations[file_path] = restored_blob.generation + logging.info( + f"Restored {file_path}: generation " + f"{generation} -> {restored_blob.generation}." + ) + return new_generations + + +def _restore_hf_commit( + old_manifest: VersionManifest, + new_version: str, +) -> str: + """Re-upload old HF data as a new commit and tag it. + + Args: + old_manifest: The manifest of the version being + restored. + new_version: The new semver version string for + tagging. + + Returns: + The commit SHA of the new HF commit. + """ + token = os.environ.get("HUGGING_FACE_TOKEN") + api = HfApi() + target_version = old_manifest.version + + operations = [] + with tempfile.TemporaryDirectory() as tmpdir: + for file_path in old_manifest.gcs.generations: + hf_hub_download( + repo_id=old_manifest.hf.repo, + repo_type=HF_REPO_TYPE, + filename=file_path, + revision=old_manifest.hf.commit, + local_dir=tmpdir, + token=token, + ) + downloaded = os.path.join(tmpdir, file_path) + operations.append( + CommitOperationAdd( + path_in_repo=file_path, + path_or_fileobj=downloaded, + ) + ) + + commit_info = api.create_commit( + token=token, + repo_id=HF_REPO_NAME, + operations=operations, + repo_type=HF_REPO_TYPE, + commit_message=( + f"Roll back to {target_version} " f"as {new_version}" + ), + ) + + try: + api.create_tag( + token=token, + repo_id=HF_REPO_NAME, + tag=new_version, + revision=commit_info.oid, + repo_type=HF_REPO_TYPE, + ) + except Exception as e: + if "already exists" in str(e) or "409" in str(e): + logging.warning( + f"Tag {new_version} already exists. " "Skipping tag creation." + ) + else: + raise + + return commit_info.oid + + +# -- Public API ---------------------------------------------------- + + +def build_manifest( + version: str, + blob_names: list[str], + hf_info: Optional[HFVersionInfo] = None, +) -> VersionManifest: + """Build a version manifest by reading generation + numbers from uploaded blobs. + + Args: + version: Semver version string. + blob_names: List of blob paths to include. + hf_info: Optional HF backend info to include. + + Returns: + A VersionManifest with generation numbers for + each blob. + """ + bucket = _get_gcs_bucket() + generations: dict[str, int] = {} + for name in blob_names: + blob = bucket.get_blob(name) + if blob is None: + raise ValueError( + f"Blob '{name}' not found in bucket " + f"'{bucket.name}' after upload." + ) + generations[name] = blob.generation + + return VersionManifest( + version=version, + created_at=_utc_now_iso(), + hf=hf_info, + gcs=GCSVersionInfo( + bucket=bucket.name, + generations=generations, + ), + ) + + +def upload_manifest( manifest: VersionManifest, - hf_repo_name: Optional[str] = None, - hf_repo_type: str = "model", ) -> None: - """Append a version manifest to the registry and upload. + """Append a version manifest to the registry and + upload to both GCS and HF. Reads the existing registry from GCS (or starts fresh), prepends the new manifest, updates the current pointer, - and writes the registry to GCS and optionally HF. + and writes the registry to both backends. Args: - bucket: GCS bucket to upload to. manifest: The version manifest to add. - hf_repo_name: If provided, also upload to this - HF repo. - hf_repo_type: HF repository type. """ + bucket = _get_gcs_bucket() registry = _read_registry_from_gcs(bucket) registry.versions.insert(0, manifest) registry.current = manifest.version - _upload_registry_to_gcs(bucket, registry) + _upload_registry_to_hf(registry) - if hf_repo_name is not None: - _upload_registry_to_hf(registry, hf_repo_name, hf_repo_type) - - -# -- Query functions ----------------------------------------------- - -def get_current_version( - bucket: storage.Bucket, -) -> Optional[str]: - """Get the current version of the bucket. - - Args: - bucket: GCS bucket to query. +def get_current_version() -> Optional[str]: + """Get the current version from the registry. Returns: The current semver version string, or None if no registry exists. """ + bucket = _get_gcs_bucket() registry = _read_registry_from_gcs(bucket) if not registry.current: return None return registry.current -def get_manifest( - bucket: storage.Bucket, - version: str, -) -> VersionManifest: +def get_manifest(version: str) -> VersionManifest: """Get the manifest for a specific version. Args: - bucket: GCS bucket to query. version: Semver version string. Returns: @@ -323,27 +418,23 @@ def get_manifest( Raises: ValueError: If the version is not in the registry. """ + bucket = _get_gcs_bucket() registry = _read_registry_from_gcs(bucket) return registry.get_version(version) -def list_versions( - bucket: storage.Bucket, -) -> list[str]: - """List all available versions in the bucket. - - Args: - bucket: GCS bucket to query. +def list_versions() -> list[str]: + """List all available versions. Returns: Sorted list of semver version strings. """ + bucket = _get_gcs_bucket() registry = _read_registry_from_gcs(bucket) return sorted(v.version for v in registry.versions) def download_versioned_file( - bucket: storage.Bucket, file_path: str, version: str, local_path: str, @@ -351,7 +442,6 @@ def download_versioned_file( """Download a specific file at a specific version. Args: - bucket: GCS bucket to download from. file_path: Path of the file within the bucket. version: Semver version string. local_path: Local path to save the file to. @@ -362,13 +452,16 @@ def download_versioned_file( Raises: ValueError: If the version or file is not found. """ - manifest = get_manifest(bucket, version) + bucket = _get_gcs_bucket() + registry = _read_registry_from_gcs(bucket) + manifest = registry.get_version(version) if file_path not in manifest.gcs.generations: raise ValueError( - f"File '{file_path}' not found in manifest for " - f"version '{version}'. Available files: " - f"{list(manifest.gcs.generations.keys())[:10]}..." + f"File '{file_path}' not found in manifest " + f"for version '{version}'. Available files: " + f"{list(manifest.gcs.generations.keys())[:10]}" + "..." ) generation = manifest.gcs.generations[file_path] @@ -385,157 +478,74 @@ def download_versioned_file( def rollback( - bucket: storage.Bucket, target_version: str, new_version: str, - hf_repo_name: str = "policyengine/policyengine-us-data", - hf_repo_type: str = "model", ) -> VersionManifest: """Roll back by releasing a new version with old data. - This treats rollback as a new release: data from - target_version is copied to the live paths (creating new - GCS generations), a new HF commit is created with the - old data, and a new manifest is published under + Treats rollback as a new release: data from + target_version is copied to the live paths (creating + new GCS generations), a new HF commit is created with + the old data, and a new manifest is published under new_version with special_operation="roll-back". Args: - bucket: GCS bucket to roll back. target_version: Semver version to roll back to. - new_version: New semver version to publish - (e.g., "1.73.0"). - hf_repo_name: HuggingFace repository name. - hf_repo_type: HuggingFace repository type. + new_version: New semver version to publish. Returns: The new VersionManifest for the rollback release. Raises: - ValueError: If target_version is not in the registry. + ValueError: If target_version is not in the + registry. """ - old_manifest = get_manifest(bucket, target_version) - - # 1. Restore GCS files by copying old generations - # to live paths, then record new generations. - new_generations: dict[str, int] = {} - for file_path, generation in old_manifest.gcs.generations.items(): - source_blob = bucket.blob(file_path, generation=generation) - bucket.copy_blob(source_blob, bucket, file_path) - # Read back the new generation - restored_blob = bucket.get_blob(file_path) - new_generations[file_path] = restored_blob.generation - logging.info( - f"Restored {file_path}: generation " - f"{generation} -> {restored_blob.generation}." - ) - - # 2. Re-upload old data to HF as a new commit - hf_commit = None - if old_manifest.hf is not None: - token = os.environ.get("HUGGING_FACE_TOKEN") - api = HfApi() - - operations = [] - with tempfile.TemporaryDirectory() as tmpdir: - for file_path in old_manifest.gcs.generations.keys(): - local_path = os.path.join(tmpdir, file_path.replace("/", "_")) - hf_hub_download( - repo_id=old_manifest.hf.repo, - repo_type=hf_repo_type, - filename=file_path, - revision=old_manifest.hf.commit, - local_dir=tmpdir, - token=token, - ) - downloaded = os.path.join(tmpdir, file_path) - operations.append( - CommitOperationAdd( - path_in_repo=file_path, - path_or_fileobj=downloaded, - ) - ) - - commit_info = api.create_commit( - token=token, - repo_id=hf_repo_name, - operations=operations, - repo_type=hf_repo_type, - commit_message=( - f"Roll back to {target_version} " f"as {new_version}" - ), - ) - hf_commit = commit_info.oid - - # Tag the new commit - try: - api.create_tag( - token=token, - repo_id=hf_repo_name, - tag=new_version, - revision=hf_commit, - repo_type=hf_repo_type, - ) - except Exception as e: - if "already exists" in str(e) or "409" in str(e): - logging.warning( - f"Tag {new_version} already exists. " - f"Skipping tag creation." - ) - else: - raise + bucket = _get_gcs_bucket() + old_manifest = _read_registry_from_gcs(bucket).get_version(target_version) + + new_gens = _restore_gcs_generations(bucket, old_manifest.gcs.generations) + hf_commit = ( + _restore_hf_commit(old_manifest, new_version) + if old_manifest.hf + else None + ) - # 3. Build and upload the new manifest - new_manifest = VersionManifest( + manifest = VersionManifest( version=new_version, - created_at=datetime.now(timezone.utc) - .isoformat() - .replace("+00:00", "Z"), + created_at=_utc_now_iso(), hf=( - HFVersionInfo(repo=hf_repo_name, commit=hf_commit) + HFVersionInfo(repo=HF_REPO_NAME, commit=hf_commit) if hf_commit else None ), gcs=GCSVersionInfo( - bucket=bucket.name, - generations=new_generations, + bucket=GCS_BUCKET_NAME, + generations=new_gens, ), special_operation="roll-back", roll_back_version=target_version, ) - upload_manifest( - bucket, - new_manifest, - hf_repo_name=hf_repo_name, - hf_repo_type=hf_repo_type, - ) + upload_manifest(manifest) logging.info( f"Rolled back to {target_version} as new " f"version {new_version}. " - f"Restored {len(new_generations)} files." + f"Restored {len(new_gens)} files." ) - return new_manifest + return manifest # -- Consumer API -------------------------------------------------- - _cached_registry: Optional[VersionRegistry] = None -def get_data_manifest( - hf_repo_name: str = "policyengine/policyengine-us-data", - hf_repo_type: str = "model", -) -> VersionRegistry: +def get_data_manifest() -> VersionRegistry: """Get the full version registry from HF. - Fetches version_manifest.json from the Hugging Face repo - and returns it as a VersionRegistry. The result is cached - in memory after the first call. - - Args: - hf_repo_name: HF repository name. - hf_repo_type: HF repository type. + Fetches version_manifest.json from the Hugging Face + repo and returns it as a VersionRegistry. The result + is cached in memory after the first call. Returns: The full VersionRegistry. @@ -545,8 +555,8 @@ def get_data_manifest( return _cached_registry local_path = hf_hub_download( - repo_id=hf_repo_name, - repo_type=hf_repo_type, + repo_id=HF_REPO_NAME, + repo_type=HF_REPO_TYPE, filename=REGISTRY_BLOB, ) with open(local_path) as f: @@ -556,19 +566,12 @@ def get_data_manifest( return _cached_registry -def get_data_version( - hf_repo_name: str = "policyengine/policyengine-us-data", - hf_repo_type: str = "model", -) -> str: +def get_data_version() -> str: """Get the current deployed data version string. Convenience wrapper around get_data_manifest(). - Args: - hf_repo_name: HF repository name. - hf_repo_type: HF repository type. - Returns: The current semver version string. """ - return get_data_manifest(hf_repo_name, hf_repo_type).current + return get_data_manifest().current From 38314b46fbfcf1b57e11adce676e1519c64cdab8 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 11 Mar 2026 20:29:36 +0100 Subject: [PATCH 4/8] Add changelog fragment for version manifest feature Co-Authored-By: Claude Opus 4.6 --- changelog.d/added/601.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/added/601.md diff --git a/changelog.d/added/601.md b/changelog.d/added/601.md new file mode 100644 index 000000000..968c338b1 --- /dev/null +++ b/changelog.d/added/601.md @@ -0,0 +1 @@ +Add unified version manifest system for semver-based dataset versioning across GCS and Hugging Face, with rollback support and a single registry file (version_manifest.json) on both backends. From a4dbe4ab380ad1151e8511c201fe73d82c6fef1d Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 11 Mar 2026 20:51:23 +0100 Subject: [PATCH 5/8] Fix ruff lint: move fixtures to conftest, remove unused imports Move pytest fixtures to conftest.py for auto-discovery (fixes F401/F811), remove unused imports from data_upload.py, apply ruff formatting. Co-Authored-By: Claude Opus 4.6 --- policyengine_us_data/tests/conftest.py | 63 +++++++++++++++++++ .../tests/fixtures/test_version_manifest.py | 58 +---------------- .../tests/test_version_manifest.py | 33 +++------- policyengine_us_data/utils/data_upload.py | 23 ++----- .../utils/version_manifest.py | 35 +++-------- 5 files changed, 88 insertions(+), 124 deletions(-) create mode 100644 policyengine_us_data/tests/conftest.py diff --git a/policyengine_us_data/tests/conftest.py b/policyengine_us_data/tests/conftest.py new file mode 100644 index 000000000..fb39787c3 --- /dev/null +++ b/policyengine_us_data/tests/conftest.py @@ -0,0 +1,63 @@ +"""Shared fixtures for version manifest tests.""" + +from unittest.mock import MagicMock + +import pytest + +from policyengine_us_data.utils.version_manifest import ( + HFVersionInfo, + GCSVersionInfo, + VersionManifest, + VersionRegistry, +) + + +@pytest.fixture +def sample_generations() -> dict[str, int]: + return { + "enhanced_cps_2024.h5": 1710203948123456, + "cps_2024.h5": 1710203948234567, + "states/AL.h5": 1710203948345678, + } + + +@pytest.fixture +def sample_hf_info() -> HFVersionInfo: + return HFVersionInfo( + repo="policyengine/policyengine-us-data", + commit="abc123def456", + ) + + +@pytest.fixture +def sample_manifest( + sample_generations: dict[str, int], + sample_hf_info: HFVersionInfo, +) -> VersionManifest: + return VersionManifest( + version="1.72.3", + created_at="2026-03-10T14:30:00Z", + hf=sample_hf_info, + gcs=GCSVersionInfo( + bucket="policyengine-us-data", + generations=sample_generations, + ), + ) + + +@pytest.fixture +def sample_registry( + sample_manifest: VersionManifest, +) -> VersionRegistry: + """A registry with one version entry.""" + return VersionRegistry( + current="1.72.3", + versions=[sample_manifest], + ) + + +@pytest.fixture +def mock_bucket() -> MagicMock: + bucket = MagicMock() + bucket.name = "policyengine-us-data" + return bucket diff --git a/policyengine_us_data/tests/fixtures/test_version_manifest.py b/policyengine_us_data/tests/fixtures/test_version_manifest.py index a6158dc76..2678f0315 100644 --- a/policyengine_us_data/tests/fixtures/test_version_manifest.py +++ b/policyengine_us_data/tests/fixtures/test_version_manifest.py @@ -1,69 +1,13 @@ -"""Shared fixtures for version manifest tests.""" +"""Helper functions for version manifest tests.""" import json from unittest.mock import MagicMock -import pytest - from policyengine_us_data.utils.version_manifest import ( - HFVersionInfo, - GCSVersionInfo, - VersionManifest, VersionRegistry, ) -@pytest.fixture -def sample_generations() -> dict[str, int]: - return { - "enhanced_cps_2024.h5": 1710203948123456, - "cps_2024.h5": 1710203948234567, - "states/AL.h5": 1710203948345678, - } - - -@pytest.fixture -def sample_hf_info() -> HFVersionInfo: - return HFVersionInfo( - repo="policyengine/policyengine-us-data", - commit="abc123def456", - ) - - -@pytest.fixture -def sample_manifest( - sample_generations: dict[str, int], - sample_hf_info: HFVersionInfo, -) -> VersionManifest: - return VersionManifest( - version="1.72.3", - created_at="2026-03-10T14:30:00Z", - hf=sample_hf_info, - gcs=GCSVersionInfo( - bucket="policyengine-us-data", - generations=sample_generations, - ), - ) - - -@pytest.fixture -def sample_registry( - sample_manifest: VersionManifest, -) -> VersionRegistry: - """A registry with one version entry.""" - return VersionRegistry( - current="1.72.3", - versions=[sample_manifest], - ) - - -@pytest.fixture -def mock_bucket() -> MagicMock: - bucket = MagicMock() - bucket.name = "policyengine-us-data" - return bucket - - def make_mock_blob(generation: int) -> MagicMock: blob = MagicMock() blob.generation = generation diff --git a/policyengine_us_data/tests/test_version_manifest.py b/policyengine_us_data/tests/test_version_manifest.py index c6c6b5058..9af6d9058 100644 --- a/policyengine_us_data/tests/test_version_manifest.py +++ b/policyengine_us_data/tests/test_version_manifest.py @@ -7,7 +7,6 @@ from google.api_core.exceptions import NotFound from policyengine_us_data.utils.version_manifest import ( - HFVersionInfo, GCSVersionInfo, VersionManifest, VersionRegistry, @@ -22,11 +21,6 @@ get_data_version, ) from policyengine_us_data.tests.fixtures.test_version_manifest import ( - sample_generations, - sample_hf_info, - sample_manifest, - sample_registry, - mock_bucket, make_mock_blob, setup_bucket_with_registry, ) @@ -46,10 +40,7 @@ def test_to_dict(self, sample_manifest): assert result["hf"]["repo"] == ("policyengine/policyengine-us-data") assert result["hf"]["commit"] == "abc123def456" assert result["gcs"]["bucket"] == "policyengine-us-data" - assert ( - result["gcs"]["generations"]["enhanced_cps_2024.h5"] - == 1710203948123456 - ) + assert result["gcs"]["generations"]["enhanced_cps_2024.h5"] == 1710203948123456 def test_from_dict(self, sample_manifest): data = { @@ -73,9 +64,7 @@ def test_from_dict(self, sample_manifest): assert result.version == "1.72.3" assert result.hf.commit == "abc123def456" assert result.hf.repo == ("policyengine/policyengine-us-data") - assert ( - result.gcs.generations["enhanced_cps_2024.h5"] == 1710203948123456 - ) + assert result.gcs.generations["enhanced_cps_2024.h5"] == 1710203948123456 assert result.gcs.bucket == "policyengine-us-data" def test_roundtrip(self, sample_manifest): @@ -128,9 +117,7 @@ def test_special_operation_included_when_set( assert data["special_operation"] == "roll-back" assert data["roll_back_version"] == "1.70.1" - def test_special_operation_roundtrip( - self, sample_generations, sample_hf_info - ): + def test_special_operation_roundtrip(self, sample_generations, sample_hf_info): manifest = VersionManifest( version="1.73.0", created_at="2026-03-10T15:00:00Z", @@ -477,9 +464,7 @@ def test_specific_version( assert isinstance(result, VersionManifest) assert result.version == "1.72.3" assert result.hf.commit == "abc123def456" - assert ( - result.gcs.generations["enhanced_cps_2024.h5"] == 1710203948123456 - ) + assert result.gcs.generations["enhanced_cps_2024.h5"] == 1710203948123456 @patch(f"{_MOD}._get_gcs_bucket") def test_nonexistent_version( @@ -730,8 +715,8 @@ def test_nonexistent_version(self, mock_get_bucket, mock_bucket): class TestUploadFilesToGcsReturnsGenerations: - @patch("policyengine_us_data.utils.data_upload." "google.auth") - @patch("policyengine_us_data.utils.data_upload." "storage") + @patch("policyengine_us_data.utils.data_upload.google.auth") + @patch("policyengine_us_data.utils.data_upload.storage") def test_returns_generations(self, mock_storage, mock_auth, tmp_path): from policyengine_us_data.utils.data_upload import ( upload_files_to_gcs, @@ -775,8 +760,8 @@ def test_returns_generations(self, mock_storage, mock_auth, tmp_path): class TestEndToEndUploadCreatesRegistry: @patch(f"{_MOD}._upload_registry_to_hf") @patch(f"{_MOD}._get_gcs_bucket") - @patch("policyengine_us_data.utils.data_upload." "google.auth") - @patch("policyengine_us_data.utils.data_upload." "storage") + @patch("policyengine_us_data.utils.data_upload.google.auth") + @patch("policyengine_us_data.utils.data_upload.storage") @patch("policyengine_us_data.utils.data_upload.HfApi") @patch("policyengine_us_data.utils.data_upload.os") def test_creates_registry( @@ -870,7 +855,7 @@ def test_returns_registry(self, mock_download, tmp_path): "version": "1.72.3", "created_at": ("2026-03-10T14:30:00Z"), "hf": { - "repo": ("policyengine/" "policyengine-us-data"), + "repo": ("policyengine/policyengine-us-data"), "commit": "abc123", }, "gcs": { diff --git a/policyengine_us_data/utils/data_upload.py b/policyengine_us_data/utils/data_upload.py index 93a10c639..a3821d6c6 100644 --- a/policyengine_us_data/utils/data_upload.py +++ b/policyengine_us_data/utils/data_upload.py @@ -5,13 +5,11 @@ CommitOperationCopy, CommitOperationDelete, ) -from huggingface_hub.errors import RevisionNotFoundError from google.cloud import storage from pathlib import Path from importlib import metadata import google.auth import httpx -import json import logging import os @@ -146,9 +144,7 @@ def upload_files_to_gcs( Dict mapping blob name to its GCS generation number. """ credentials, project_id = google.auth.default() - storage_client = storage.Client( - credentials=credentials, project=project_id - ) + storage_client = storage.Client(credentials=credentials, project=project_id) bucket = storage_client.bucket(gcs_bucket_name) generations: Dict[str, int] = {} @@ -156,9 +152,7 @@ def upload_files_to_gcs( file_path = Path(file_path) blob = bucket.blob(file_path.name) blob.upload_from_filename(file_path) - logging.info( - f"Uploaded {file_path.name} to GCS bucket " f"{gcs_bucket_name}." - ) + logging.info(f"Uploaded {file_path.name} to GCS bucket {gcs_bucket_name}.") # Set metadata blob.metadata = {"version": version} @@ -207,9 +201,7 @@ def upload_local_area_file( # Upload to GCS with subdirectory credentials, project_id = google.auth.default() - storage_client = storage.Client( - credentials=credentials, project=project_id - ) + storage_client = storage.Client(credentials=credentials, project=project_id) bucket = storage_client.bucket(gcs_bucket_name) blob_name = f"{subdirectory}/{file_path.name}" @@ -237,12 +229,11 @@ def upload_local_area_file( repo_type=hf_repo_type, token=token, commit_message=( - f"Upload {subdirectory}/{file_path.name} " f"for version {version}" + f"Upload {subdirectory}/{file_path.name} for version {version}" ), ) logging.info( - f"Uploaded {subdirectory}/{file_path.name} to " - f"Hugging Face {hf_repo_name}." + f"Uploaded {subdirectory}/{file_path.name} to Hugging Face {hf_repo_name}." ) return generation @@ -389,9 +380,7 @@ def upload_to_staging_hf( f"Uploaded batch {i // batch_size + 1}: {len(operations)} files to staging/" ) - logging.info( - f"Total: uploaded {total_uploaded} files to staging/ in HuggingFace" - ) + logging.info(f"Total: uploaded {total_uploaded} files to staging/ in HuggingFace") return total_uploaded diff --git a/policyengine_us_data/utils/version_manifest.py b/policyengine_us_data/utils/version_manifest.py index c7d6348a7..37834ff39 100644 --- a/policyengine_us_data/utils/version_manifest.py +++ b/policyengine_us_data/utils/version_manifest.py @@ -200,7 +200,7 @@ def _upload_registry_to_gcs( data = json.dumps(registry.to_dict(), indent=2) blob = bucket.blob(REGISTRY_BLOB) blob.upload_from_string(data, content_type="application/json") - logging.info("Uploaded registry to GCS " f"(current={registry.current}).") + logging.info(f"Uploaded registry to GCS (current={registry.current}).") def _upload_registry_to_hf( @@ -211,9 +211,7 @@ def _upload_registry_to_hf( api = HfApi() data = json.dumps(registry.to_dict(), indent=2) - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False - ) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: f.write(data) tmp_path = f.name @@ -224,13 +222,9 @@ def _upload_registry_to_hf( repo_id=HF_REPO_NAME, repo_type=HF_REPO_TYPE, token=token, - commit_message=( - "Update version registry " f"(current={registry.current})" - ), - ) - logging.info( - f"Uploaded {REGISTRY_BLOB} to " f"HF repo {HF_REPO_NAME}." + commit_message=(f"Update version registry (current={registry.current})"), ) + logging.info(f"Uploaded {REGISTRY_BLOB} to HF repo {HF_REPO_NAME}.") finally: os.unlink(tmp_path) @@ -305,9 +299,7 @@ def _restore_hf_commit( repo_id=HF_REPO_NAME, operations=operations, repo_type=HF_REPO_TYPE, - commit_message=( - f"Roll back to {target_version} " f"as {new_version}" - ), + commit_message=(f"Roll back to {target_version} as {new_version}"), ) try: @@ -320,9 +312,7 @@ def _restore_hf_commit( ) except Exception as e: if "already exists" in str(e) or "409" in str(e): - logging.warning( - f"Tag {new_version} already exists. " "Skipping tag creation." - ) + logging.warning(f"Tag {new_version} already exists. Skipping tag creation.") else: raise @@ -355,8 +345,7 @@ def build_manifest( blob = bucket.get_blob(name) if blob is None: raise ValueError( - f"Blob '{name}' not found in bucket " - f"'{bucket.name}' after upload." + f"Blob '{name}' not found in bucket '{bucket.name}' after upload." ) generations[name] = blob.generation @@ -505,19 +494,13 @@ def rollback( new_gens = _restore_gcs_generations(bucket, old_manifest.gcs.generations) hf_commit = ( - _restore_hf_commit(old_manifest, new_version) - if old_manifest.hf - else None + _restore_hf_commit(old_manifest, new_version) if old_manifest.hf else None ) manifest = VersionManifest( version=new_version, created_at=_utc_now_iso(), - hf=( - HFVersionInfo(repo=HF_REPO_NAME, commit=hf_commit) - if hf_commit - else None - ), + hf=(HFVersionInfo(repo=HF_REPO_NAME, commit=hf_commit) if hf_commit else None), gcs=GCSVersionInfo( bucket=GCS_BUCKET_NAME, generations=new_gens, From c5ee44b0e4a89e980a3f3977f26b50ab848d1e23 Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Thu, 12 Mar 2026 16:04:59 -0400 Subject: [PATCH 6/8] Apply ruff formatting to data_upload.py Co-Authored-By: Claude Opus 4.6 --- policyengine_us_data/utils/data_upload.py | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/policyengine_us_data/utils/data_upload.py b/policyengine_us_data/utils/data_upload.py index 4e4b96d9a..f585e58fa 100644 --- a/policyengine_us_data/utils/data_upload.py +++ b/policyengine_us_data/utils/data_upload.py @@ -145,9 +145,7 @@ def upload_files_to_gcs( Dict mapping blob name to its GCS generation number. """ credentials, project_id = google.auth.default() - storage_client = storage.Client( - credentials=credentials, project=project_id - ) + storage_client = storage.Client(credentials=credentials, project=project_id) bucket = storage_client.bucket(gcs_bucket_name) generations: Dict[str, int] = {} @@ -155,9 +153,7 @@ def upload_files_to_gcs( file_path = Path(file_path) blob = bucket.blob(file_path.name) blob.upload_from_filename(file_path) - logging.info( - f"Uploaded {file_path.name} to GCS bucket {gcs_bucket_name}." - ) + logging.info(f"Uploaded {file_path.name} to GCS bucket {gcs_bucket_name}.") # Set metadata blob.metadata = {"version": version} @@ -205,9 +201,7 @@ def upload_local_area_file( # Upload to GCS with subdirectory credentials, project_id = google.auth.default() - storage_client = storage.Client( - credentials=credentials, project=project_id - ) + storage_client = storage.Client(credentials=credentials, project=project_id) bucket = storage_client.bucket(gcs_bucket_name) blob_name = f"{subdirectory}/{file_path.name}" @@ -386,9 +380,7 @@ def upload_to_staging_hf( f"Uploaded batch {i // batch_size + 1}: {len(operations)} files to staging/" ) - logging.info( - f"Total: uploaded {total_uploaded} files to staging/ in HuggingFace" - ) + logging.info(f"Total: uploaded {total_uploaded} files to staging/ in HuggingFace") return total_uploaded @@ -539,9 +531,7 @@ def upload_from_hf_staging_to_gcs( token = os.environ.get("HUGGING_FACE_TOKEN") credentials, project_id = google.auth.default() - storage_client = storage.Client( - credentials=credentials, project=project_id - ) + storage_client = storage.Client(credentials=credentials, project=project_id) bucket = storage_client.bucket(gcs_bucket_name) uploaded = 0 From 8438d74b79c3da2597e4643b9da43d102a3a7ed3 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Sat, 14 Mar 2026 01:52:21 +0100 Subject: [PATCH 7/8] Address PR review: fix semver sort, use exist_ok, add tests, dedupe constants - Fix list_versions() to use packaging.version.Version for proper semver ordering instead of lexicographic string sort - Replace broad Exception catch with exist_ok=True on HfApi.create_tag() in both version_manifest.py and data_upload.py - Add test for rollback when hf=None - Add semver sorting edge-case test (1.9.0 vs 1.10.0 vs 1.100.0) - Deduplicate GCS/HF constants by importing from version_manifest - Update changelog fragment to mention consumer APIs Co-Authored-By: Claude Opus 4.6 --- changelog.d/added/601.md | 2 +- .../tests/test_version_manifest.py | 95 +++++++++++++++++++ policyengine_us_data/utils/data_upload.py | 71 +++++++------- .../utils/version_manifest.py | 27 +++--- pyproject.toml | 1 + 5 files changed, 143 insertions(+), 53 deletions(-) diff --git a/changelog.d/added/601.md b/changelog.d/added/601.md index 968c338b1..2148d164e 100644 --- a/changelog.d/added/601.md +++ b/changelog.d/added/601.md @@ -1 +1 @@ -Add unified version manifest system for semver-based dataset versioning across GCS and Hugging Face, with rollback support and a single registry file (version_manifest.json) on both backends. +Add unified version manifest system for semver-based dataset versioning across GCS and Hugging Face, with rollback support and a single registry file (version_manifest.json) on both backends. Exposes `get_data_version()` and `get_data_manifest()` as public consumer APIs. diff --git a/policyengine_us_data/tests/test_version_manifest.py b/policyengine_us_data/tests/test_version_manifest.py index 9af6d9058..b1a150bf9 100644 --- a/policyengine_us_data/tests/test_version_manifest.py +++ b/policyengine_us_data/tests/test_version_manifest.py @@ -536,6 +536,42 @@ def test_empty(self, mock_get_bucket, mock_bucket): assert result == [] + @patch(f"{_MOD}._get_gcs_bucket") + def test_semver_ordering(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket + versions = [ + "1.100.0", + "2.0.0", + "1.9.0", + "1.10.0", + ] + manifests = [ + VersionManifest( + version=v, + created_at="t", + hf=None, + gcs=GCSVersionInfo( + bucket="b", + generations={"f.h5": i}, + ), + ) + for i, v in enumerate(versions) + ] + registry = VersionRegistry( + current="2.0.0", + versions=manifests, + ) + setup_bucket_with_registry(mock_bucket, registry) + + result = list_versions() + + assert result == [ + "1.9.0", + "1.10.0", + "1.100.0", + "2.0.0", + ] + # -- download_versioned_file tests --------------------------------- @@ -697,6 +733,65 @@ def mock_get_blob(name): assert "1.73.0" in commit_msg mock_api.create_tag.assert_called_once() + @patch(f"{_MOD}._upload_registry_to_hf") + @patch(f"{_MOD}._get_gcs_bucket") + def test_rollback_without_hf( + self, + mock_get_bucket, + mock_hf_upload, + mock_bucket, + ): + mock_get_bucket.return_value = mock_bucket + + old_manifest = VersionManifest( + version="1.72.3", + created_at="2026-03-10T14:30:00Z", + hf=None, + gcs=GCSVersionInfo( + bucket="policyengine-us-data", + generations={"file.h5": 111}, + ), + ) + registry = VersionRegistry( + current="1.72.3", + versions=[old_manifest], + ) + registry_json = json.dumps(registry.to_dict()) + written = {} + + def mock_blob(name, generation=None): + if name == "version_manifest.json": + b = MagicMock() + b.name = name + b.download_as_text.return_value = registry_json + written[name] = b + return b + blob = MagicMock() + blob.name = name + blob.generation = generation + return blob + + mock_bucket.blob.side_effect = mock_blob + + restored_blob = MagicMock() + restored_blob.generation = 222 + mock_bucket.get_blob.return_value = restored_blob + + result = rollback( + target_version="1.72.3", + new_version="1.73.0", + ) + + assert result.version == "1.73.0" + assert result.hf is None + assert result.special_operation == "roll-back" + assert mock_bucket.copy_blob.call_count == 1 + + blob = written["version_manifest.json"] + written_json = blob.upload_from_string.call_args[0][0] + registry_data = json.loads(written_json) + assert registry_data["versions"][0]["hf"] is None + @patch(f"{_MOD}._get_gcs_bucket") def test_nonexistent_version(self, mock_get_bucket, mock_bucket): mock_get_bucket.return_value = mock_bucket diff --git a/policyengine_us_data/utils/data_upload.py b/policyengine_us_data/utils/data_upload.py index f585e58fa..2dfb24ed0 100644 --- a/policyengine_us_data/utils/data_upload.py +++ b/policyengine_us_data/utils/data_upload.py @@ -14,6 +14,11 @@ import logging import os +from policyengine_us_data.utils.version_manifest import ( + GCS_BUCKET_NAME, + HF_REPO_NAME, + HF_REPO_TYPE, +) from tenacity import ( retry, stop_after_attempt, @@ -29,14 +34,12 @@ def upload_data_files( files: List[str], - gcs_bucket_name: str = "policyengine-us-data", - hf_repo_name: str = "policyengine/policyengine-us-data", - hf_repo_type: str = "model", + gcs_bucket_name: str = GCS_BUCKET_NAME, + hf_repo_name: str = HF_REPO_NAME, + hf_repo_type: str = HF_REPO_TYPE, version: Optional[str] = None, ) -> None: from policyengine_us_data.utils.version_manifest import ( - GCS_BUCKET_NAME, - HF_REPO_NAME, GCSVersionInfo, HFVersionInfo, VersionManifest, @@ -76,8 +79,8 @@ def upload_data_files( def upload_files_to_hf( files: List[str], version: str, - hf_repo_name: str = "policyengine/policyengine-us-data", - hf_repo_type: str = "model", + hf_repo_name: str = HF_REPO_NAME, + hf_repo_type: str = HF_REPO_TYPE, ) -> str: """Upload files to Hugging Face repository and tag the commit with the version. @@ -110,25 +113,17 @@ def upload_files_to_hf( ) logging.info(f"Uploaded files to Hugging Face repository {hf_repo_name}.") - # Tag commit with version (convenience for HF web UI) - try: - api.create_tag( - token=token, - repo_id=hf_repo_name, - tag=version, - revision=commit_info.oid, - repo_type=hf_repo_type, - ) - logging.info( - f"Tagged commit with {version} in Hugging Face repository {hf_repo_name}." - ) - except Exception as e: - if "Tag reference exists already" in str(e) or "409" in str(e): - logging.warning( - f"Tag {version} already exists in {hf_repo_name}. Skipping tag creation." - ) - else: - raise + api.create_tag( + token=token, + repo_id=hf_repo_name, + tag=version, + revision=commit_info.oid, + repo_type=hf_repo_type, + exist_ok=True, + ) + logging.info( + f"Tagged commit with {version} in Hugging Face repository {hf_repo_name}." + ) return commit_info.oid @@ -136,7 +131,7 @@ def upload_files_to_hf( def upload_files_to_gcs( files: List[str], version: str, - gcs_bucket_name: str = "policyengine-us-data", + gcs_bucket_name: str = GCS_BUCKET_NAME, ) -> Dict[str, int]: """Upload files to Google Cloud Storage and set metadata with the version. @@ -174,9 +169,9 @@ def upload_files_to_gcs( def upload_local_area_file( file_path: str, subdirectory: str, - gcs_bucket_name: str = "policyengine-us-data", - hf_repo_name: str = "policyengine/policyengine-us-data", - hf_repo_type: str = "model", + gcs_bucket_name: str = GCS_BUCKET_NAME, + hf_repo_name: str = HF_REPO_NAME, + hf_repo_type: str = HF_REPO_TYPE, version: str = None, skip_hf: bool = False, ) -> int: @@ -241,8 +236,8 @@ def upload_local_area_file( def upload_local_area_batch_to_hf( files_with_subdirs: List[tuple], - hf_repo_name: str = "policyengine/policyengine-us-data", - hf_repo_type: str = "model", + hf_repo_name: str = HF_REPO_NAME, + hf_repo_type: str = HF_REPO_TYPE, version: str = None, ): """ @@ -327,8 +322,8 @@ def hf_create_commit_with_retry( def upload_to_staging_hf( files_with_paths: List[Tuple[Path, str]], version: str, - hf_repo_name: str = "policyengine/policyengine-us-data", - hf_repo_type: str = "model", + hf_repo_name: str = HF_REPO_NAME, + hf_repo_type: str = HF_REPO_TYPE, batch_size: int = 50, ) -> int: """ @@ -387,8 +382,8 @@ def upload_to_staging_hf( def promote_staging_to_production_hf( files: List[str], version: str, - hf_repo_name: str = "policyengine/policyengine-us-data", - hf_repo_type: str = "model", + hf_repo_name: str = HF_REPO_NAME, + hf_repo_type: str = HF_REPO_TYPE, ) -> int: """ Atomically promote files from staging/ to production paths. @@ -455,8 +450,8 @@ def promote_staging_to_production_hf( def cleanup_staging_hf( files: List[str], version: str, - hf_repo_name: str = "policyengine/policyengine-us-data", - hf_repo_type: str = "model", + hf_repo_name: str = HF_REPO_NAME, + hf_repo_type: str = HF_REPO_TYPE, ) -> int: """ Clean up staging folder after successful promotion. diff --git a/policyengine_us_data/utils/version_manifest.py b/policyengine_us_data/utils/version_manifest.py index 37834ff39..b51df5868 100644 --- a/policyengine_us_data/utils/version_manifest.py +++ b/policyengine_us_data/utils/version_manifest.py @@ -17,6 +17,7 @@ from typing import Any, Optional import google.auth +from packaging.version import Version from google.api_core.exceptions import NotFound from google.cloud import storage from huggingface_hub import ( @@ -302,19 +303,14 @@ def _restore_hf_commit( commit_message=(f"Roll back to {target_version} as {new_version}"), ) - try: - api.create_tag( - token=token, - repo_id=HF_REPO_NAME, - tag=new_version, - revision=commit_info.oid, - repo_type=HF_REPO_TYPE, - ) - except Exception as e: - if "already exists" in str(e) or "409" in str(e): - logging.warning(f"Tag {new_version} already exists. Skipping tag creation.") - else: - raise + api.create_tag( + token=token, + repo_id=HF_REPO_NAME, + tag=new_version, + revision=commit_info.oid, + repo_type=HF_REPO_TYPE, + exist_ok=True, + ) return commit_info.oid @@ -420,7 +416,10 @@ def list_versions() -> list[str]: """ bucket = _get_gcs_bucket() registry = _read_registry_from_gcs(bucket) - return sorted(v.version for v in registry.versions) + return sorted( + (v.version for v in registry.versions), + key=Version, + ) def download_versioned_file( diff --git a/pyproject.toml b/pyproject.toml index 46e23bfaf..1a15adaee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ "xlrd>=2.0.2", "spm-calculator>=0.1.0", "tenacity>=8.0.0", + "packaging>=21.0", ] [project.optional-dependencies] From 32f0d639d887b73876cef082df4d21d5b778cbd7 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Sat, 14 Mar 2026 02:01:54 +0100 Subject: [PATCH 8/8] Update uv.lock for packaging dependency Co-Authored-By: Claude Opus 4.6 --- uv.lock | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/uv.lock b/uv.lock index e554b94ef..92492e51e 100644 --- a/uv.lock +++ b/uv.lock @@ -610,6 +610,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f8/0a/a3871375c7b9727edaeeea994bfff7c63ff7804c9829c19309ba2e058807/greenlet-3.3.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:b01548f6e0b9e9784a2c99c5651e5dc89ffcbe870bc5fb2e5ef864e9cc6b5dcb", size = 276379, upload-time = "2025-12-04T14:23:30.498Z" }, { url = "https://files.pythonhosted.org/packages/43/ab/7ebfe34dce8b87be0d11dae91acbf76f7b8246bf9d6b319c741f99fa59c6/greenlet-3.3.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:349345b770dc88f81506c6861d22a6ccd422207829d2c854ae2af8025af303e3", size = 597294, upload-time = "2025-12-04T14:50:06.847Z" }, { url = "https://files.pythonhosted.org/packages/a4/39/f1c8da50024feecd0793dbd5e08f526809b8ab5609224a2da40aad3a7641/greenlet-3.3.0-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e8e18ed6995e9e2c0b4ed264d2cf89260ab3ac7e13555b8032b25a74c6d18655", size = 607742, upload-time = "2025-12-04T14:57:42.349Z" }, + { url = "https://files.pythonhosted.org/packages/77/cb/43692bcd5f7a0da6ec0ec6d58ee7cddb606d055ce94a62ac9b1aa481e969/greenlet-3.3.0-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:c024b1e5696626890038e34f76140ed1daf858e37496d33f2af57f06189e70d7", size = 622297, upload-time = "2025-12-04T15:07:13.552Z" }, { url = "https://files.pythonhosted.org/packages/75/b0/6bde0b1011a60782108c01de5913c588cf51a839174538d266de15e4bf4d/greenlet-3.3.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:047ab3df20ede6a57c35c14bf5200fcf04039d50f908270d3f9a7a82064f543b", size = 609885, upload-time = "2025-12-04T14:26:02.368Z" }, { url = "https://files.pythonhosted.org/packages/49/0e/49b46ac39f931f59f987b7cd9f34bfec8ef81d2a1e6e00682f55be5de9f4/greenlet-3.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2d9ad37fc657b1102ec880e637cccf20191581f75c64087a549e66c57e1ceb53", size = 1567424, upload-time = "2025-12-04T15:04:23.757Z" }, { url = "https://files.pythonhosted.org/packages/05/f5/49a9ac2dff7f10091935def9165c90236d8f175afb27cbed38fb1d61ab6b/greenlet-3.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83cd0e36932e0e7f36a64b732a6f60c2fc2df28c351bae79fbaf4f8092fe7614", size = 1636017, upload-time = "2025-12-04T14:27:29.688Z" }, @@ -617,6 +618,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/2f/28592176381b9ab2cafa12829ba7b472d177f3acc35d8fbcf3673d966fff/greenlet-3.3.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:a1e41a81c7e2825822f4e068c48cb2196002362619e2d70b148f20a831c00739", size = 275140, upload-time = "2025-12-04T14:23:01.282Z" }, { url = "https://files.pythonhosted.org/packages/2c/80/fbe937bf81e9fca98c981fe499e59a3f45df2a04da0baa5c2be0dca0d329/greenlet-3.3.0-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9f515a47d02da4d30caaa85b69474cec77b7929b2e936ff7fb853d42f4bf8808", size = 599219, upload-time = "2025-12-04T14:50:08.309Z" }, { url = "https://files.pythonhosted.org/packages/c2/ff/7c985128f0514271b8268476af89aee6866df5eec04ac17dcfbc676213df/greenlet-3.3.0-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7d2d9fd66bfadf230b385fdc90426fcd6eb64db54b40c495b72ac0feb5766c54", size = 610211, upload-time = "2025-12-04T14:57:43.968Z" }, + { url = "https://files.pythonhosted.org/packages/79/07/c47a82d881319ec18a4510bb30463ed6891f2ad2c1901ed5ec23d3de351f/greenlet-3.3.0-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:30a6e28487a790417d036088b3bcb3f3ac7d8babaa7d0139edbaddebf3af9492", size = 624311, upload-time = "2025-12-04T15:07:14.697Z" }, { url = "https://files.pythonhosted.org/packages/fd/8e/424b8c6e78bd9837d14ff7df01a9829fc883ba2ab4ea787d4f848435f23f/greenlet-3.3.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:087ea5e004437321508a8d6f20efc4cfec5e3c30118e1417ea96ed1d93950527", size = 612833, upload-time = "2025-12-04T14:26:03.669Z" }, { url = "https://files.pythonhosted.org/packages/b5/ba/56699ff9b7c76ca12f1cdc27a886d0f81f2189c3455ff9f65246780f713d/greenlet-3.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ab97cf74045343f6c60a39913fa59710e4bd26a536ce7ab2397adf8b27e67c39", size = 1567256, upload-time = "2025-12-04T15:04:25.276Z" }, { url = "https://files.pythonhosted.org/packages/1e/37/f31136132967982d698c71a281a8901daf1a8fbab935dce7c0cf15f942cc/greenlet-3.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5375d2e23184629112ca1ea89a53389dddbffcf417dad40125713d88eb5f96e8", size = 1636483, upload-time = "2025-12-04T14:27:30.804Z" }, @@ -1820,6 +1822,7 @@ dependencies = [ { name = "microdf-python" }, { name = "microimpute" }, { name = "openpyxl" }, + { name = "packaging" }, { name = "pandas" }, { name = "pip-system-certs" }, { name = "policyengine-core" }, @@ -1871,6 +1874,7 @@ requires-dist = [ { name = "microdf-python", specifier = ">=1.2.1" }, { name = "microimpute", specifier = ">=1.15.1" }, { name = "openpyxl", specifier = ">=3.1.5" }, + { name = "packaging", specifier = ">=21.0" }, { name = "pandas", specifier = ">=2.3.1" }, { name = "pip-system-certs", specifier = ">=3.0" }, { name = "policyengine-core", specifier = ">=3.23.6" },