Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/allow-unmanaged-local-datasets.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow unmanaged local dataset paths when callers explicitly opt out of release-bundle enforcement.
28 changes: 28 additions & 0 deletions src/policyengine/provenance/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,23 @@ def resolve_dataset_reference(country_id: str, dataset: str) -> str:
return artifact.uri


def _existing_local_dataset_path(dataset: str) -> Optional[Path]:
    """Return the resolved path if ``dataset`` is an existing, path-like reference.

    A reference counts as a local path only when it exists on disk (after
    ``~`` expansion) AND looks like a path: it is absolute, starts with
    ``~`` or ``.``, contains a path separator, or ends in ``.h5``/``.hdf5``.
    A bare name that merely happens to match a file in the working directory
    is therefore not treated as a local path.

    Args:
        dataset: The dataset reference supplied by the caller.

    Returns:
        The resolved ``Path`` when both conditions hold, otherwise ``None``.
        ``None`` is also returned when the string cannot be probed as a
        filesystem path at all.
    """
    try:
        path = Path(dataset).expanduser()
        if not path.exists():
            return None
    except (OSError, RuntimeError, ValueError):
        # expanduser() raises RuntimeError for "~unknown-user/..."; exists()
        # lets OSError such as ENAMETOOLONG through (and, on older Pythons,
        # ValueError for embedded NULs). None of these can be an existing
        # local dataset, so report "not a local path" instead of crashing.
        return None

    is_path_like = (
        path.is_absolute()
        or dataset.startswith(("~", "."))
        or os.sep in dataset
        or (os.altsep is not None and os.altsep in dataset)
        or path.suffix.lower() in {".h5", ".hdf5"}
    )
    if not is_path_like:
        return None
    return path.resolve()


def resolve_managed_dataset_reference(
country_id: str,
dataset: Optional[str] = None,
Expand Down Expand Up @@ -472,6 +489,17 @@ def resolve_managed_dataset_reference(
"bypass bundle enforcement."
)

local_dataset_path = _existing_local_dataset_path(dataset)
if local_dataset_path is not None:
if allow_unmanaged:
return str(local_dataset_path)
raise ValueError(
"Local dataset paths bypass the policyengine.py release bundle. "
"Pass a manifest dataset name or omit `dataset` to use the certified "
"default dataset. Set `allow_unmanaged=True` only if you intend to "
"run against a local dataset outside the bundle."
)

return resolve_dataset_reference(country_id, dataset)


Expand Down
42 changes: 42 additions & 0 deletions tests/test_release_manifests.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,29 @@ def test__given_explicit_uri__then_managed_resolution_requires_opt_in(self):
== dataset
)

def test__given_local_dataset_path__then_managed_resolution_requires_opt_in(
    self,
    tmp_path,
):
    """An existing local file is rejected unless ``allow_unmanaged=True``."""
    # Only path resolution is exercised, so the file content is a placeholder.
    local_file = tmp_path / "local_2100.h5"
    local_file.write_bytes(b"not a real h5; resolution only")
    local_reference = str(local_file)

    rejection = None
    try:
        resolve_managed_dataset_reference("us", local_reference)
    except ValueError as error:
        rejection = error

    if rejection is None:
        raise AssertionError("Expected local dataset path to be rejected")
    expected_fragment = "Local dataset paths bypass the policyengine.py release bundle"
    assert expected_fragment in str(rejection)

    # With the explicit opt-in, the resolved absolute path is returned.
    opted_in = resolve_managed_dataset_reference(
        "us",
        local_reference,
        allow_unmanaged=True,
    )
    assert opted_in == str(local_file.resolve())

def test__given_versioned_dataset_url__then_logical_name_drops_version(self):
dataset = "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.73.0"

Expand Down Expand Up @@ -633,6 +656,25 @@ def test__given_us_unmanaged_dataset_uri__then_source_is_not_rewritten(self):
assert microsim.policyengine_bundle["runtime_dataset_uri"] == dataset
assert microsim.policyengine_bundle["runtime_dataset_source"] == dataset

def test__given_us_unmanaged_local_dataset__then_source_is_local_path(
    self,
    tmp_path,
):
    """An opted-in local dataset is passed through and recorded as its resolved path."""
    # Only the source plumbing is exercised, so the file content is a placeholder.
    local_file = tmp_path / "local_2100.h5"
    local_file.write_bytes(b"not a real h5; source plumbing only")
    expected_path = str(local_file.resolve())

    with patch("policyengine_us.Microsimulation") as mock_microsimulation:
        microsim = managed_us_microsimulation(
            dataset=str(local_file),
            allow_unmanaged=True,
        )

    # The mocked Microsimulation must receive the resolved local path.
    passed_kwargs = mock_microsimulation.call_args.kwargs
    assert passed_kwargs["dataset"] == expected_path

    # The bundle metadata must record the same path as both URI and source.
    bundle = microsim.policyengine_bundle
    assert bundle["runtime_dataset"] == "local_2100"
    assert bundle["runtime_dataset_uri"] == expected_path
    assert bundle["runtime_dataset_source"] == expected_path

def test__given_uk_managed_dataset_name__then_resolves_within_bundle(self):
with patch("policyengine_uk.Microsimulation") as mock_microsimulation:
microsim = managed_uk_microsimulation(dataset="enhanced_frs_2023_24")
Expand Down
Loading