From 4eacf7112ba2aab88ee3418633bb397d097a548c Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 4 Mar 2026 21:15:58 +0100 Subject: [PATCH 1/8] Add entity-level HDFStore output format alongside h5py The stacked_dataset_builder now produces a Pandas HDFStore file (.hdfstore.h5) in addition to the existing h5py file. The HDFStore contains one table per entity (person, household, tax_unit, spm_unit, family, marital_unit) plus an embedded _variable_metadata manifest recording each variable's entity and uprating parameter path. The upload pipeline uploads HDFStore files to dedicated subdirectories (states_hdfstore/, districts_hdfstore/, cities_hdfstore/). A comparison test (test_format_comparison.py) validates that both formats contain identical data for all variables. Co-Authored-By: Claude Opus 4.6 --- .../publish_local_area.py | 32 ++ .../stacked_dataset_builder.py | 164 ++++++++++ .../tests/test_format_comparison.py | 285 ++++++++++++++++++ 3 files changed, 481 insertions(+) create mode 100644 policyengine_us_data/tests/test_format_comparison.py diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py b/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py index 4963f397..42bfd1b7 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py +++ b/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py @@ -280,8 +280,18 @@ def build_and_upload_states( print(f"Uploading {state_code}.h5 to GCP...") upload_local_area_file(str(output_path), "states", skip_hf=True) + # Upload HDFStore file if it exists + hdfstore_path = str(output_path).replace(".h5", ".hdfstore.h5") + if os.path.exists(hdfstore_path): + print(f"Uploading {state_code}.hdfstore.h5 to GCP...") + upload_local_area_file( + hdfstore_path, "states_hdfstore", skip_hf=True + ) + # Queue for batched HuggingFace upload hf_queue.append((str(output_path), "states")) + if os.path.exists(hdfstore_path): + hf_queue.append((hdfstore_path, "states_hdfstore")) record_completed_state(state_code) print(f"Completed {state_code}") @@ -352,8 +362,18 @@ def build_and_upload_districts( print(f"Uploading {friendly_name}.h5 to GCP...") upload_local_area_file(str(output_path), "districts", skip_hf=True) + # Upload HDFStore file if it exists + hdfstore_path = str(output_path).replace(".h5", ".hdfstore.h5") + if os.path.exists(hdfstore_path): + print(f"Uploading {friendly_name}.hdfstore.h5 to GCP...") + upload_local_area_file( + hdfstore_path, "districts_hdfstore", skip_hf=True + ) + # Queue for batched HuggingFace upload hf_queue.append((str(output_path), "districts")) + if os.path.exists(hdfstore_path): + hf_queue.append((hdfstore_path, "districts_hdfstore")) record_completed_district(friendly_name) print(f"Completed {friendly_name}") @@ -424,8 +444,20 @@ def build_and_upload_cities( str(output_path), "cities", skip_hf=True ) + # Upload HDFStore file if it exists + hdfstore_path = str(output_path).replace( + ".h5", ".hdfstore.h5" + ) + if os.path.exists(hdfstore_path): + print("Uploading NYC.hdfstore.h5 to GCP...") + upload_local_area_file( + hdfstore_path, "cities_hdfstore", skip_hf=True + ) + # Queue for batched HuggingFace upload hf_queue.append((str(output_path), "cities")) + if os.path.exists(hdfstore_path): + hf_queue.append((hdfstore_path, "cities_hdfstore")) record_completed_city("NYC") print("Completed NYC") diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/stacked_dataset_builder.py b/policyengine_us_data/datasets/cps/local_area_calibration/stacked_dataset_builder.py index 010e151f..c4d449ce 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/stacked_dataset_builder.py +++ b/policyengine_us_data/datasets/cps/local_area_calibration/stacked_dataset_builder.py @@ -59,6 +59,156 @@ def get_county_name(county_index: int) -> str: return County._member_names_[county_index] +def _split_into_entity_dfs(combined_df, system, vars_to_save, time_period): + """Split person-level DataFrame into entity-level DataFrames. + + The combined_df has columns named ``variable__period`` (e.g. + ``employment_income__2024``). This function strips the period suffix, + classifies each variable by entity, and returns one DataFrame per + entity with clean column names. + + For group entities the rows are deduplicated by entity ID so that each + entity appears exactly once. + """ + + suffix = f"__{time_period}" + + # Build a mapping from clean variable name -> column in combined_df + col_map = {} + for col in combined_df.columns: + if col.endswith(suffix): + clean = col[: -len(suffix)] + col_map[clean] = col + + # Entity classification buckets + ENTITIES = [ + "person", + "household", + "tax_unit", + "spm_unit", + "family", + "marital_unit", + ] + entity_cols = {e: [] for e in ENTITIES} + + # Person-level entity membership ID columns (person_household_id, etc.) + person_ref_cols = [] + + for var in sorted(vars_to_save): + if var not in col_map: + continue + if var in system.variables: + entity_key = system.variables[var].entity.key + entity_cols[entity_key].append(var) + else: + # Geography/custom vars without system entry go to household + entity_cols["household"].append(var) + + # --- Person DataFrame --- + person_vars = ["person_id"] + entity_cols["person"] + # Add person-level entity membership columns + for entity in ENTITIES[1:]: # skip person + ref_col = f"person_{entity}_id" + if ref_col in col_map: + person_vars.append(ref_col) + person_ref_cols.append(ref_col) + + person_src_cols = [col_map[v] for v in person_vars if v in col_map] + person_df = combined_df[person_src_cols].copy() + person_df.columns = [ + c[: -len(suffix)] if c.endswith(suffix) else c + for c in person_df.columns + ] + + entity_dfs = {"person": person_df} + + # --- Group entity DataFrames: deduplicate by entity ID --- + for entity in ENTITIES[1:]: + id_col = f"{entity}_id" + person_ref = f"person_{entity}_id" + # Use person_ref column if available, else id_col + src_id = person_ref if person_ref in col_map else id_col + + if src_id not in col_map: + continue + + # Collect columns for this entity + cols_to_use = [src_id] + [ + v for v in entity_cols[entity] if v != id_col and v in col_map + ] + src_cols = [col_map[v] for v in cols_to_use] + df = combined_df[src_cols].copy() + # Strip period suffix + df.columns = [ + c[: -len(suffix)] if c.endswith(suffix) else c + for c in df.columns + ] + # Rename person_X_id -> X_id if needed + if src_id == person_ref and person_ref != id_col: + df = df.rename(columns={person_ref: id_col}) + # Deduplicate + df = df.drop_duplicates(subset=[id_col]).reset_index(drop=True) + entity_dfs[entity] = df + + return entity_dfs + + +def _build_uprating_manifest(vars_to_save, system): + """Build manifest of variable metadata for embedding in HDFStore.""" + records = [] + for var in sorted(vars_to_save): + entity = ( + system.variables[var].entity.key + if var in system.variables + else "unknown" + ) + uprating = "" + if var in system.variables: + uprating = getattr(system.variables[var], "uprating", None) or "" + records.append( + {"variable": var, "entity": entity, "uprating": uprating} + ) + return pd.DataFrame(records) + + +def _save_hdfstore(entity_dfs, manifest_df, output_path, time_period): + """Save entity DataFrames and manifest to a Pandas HDFStore file.""" + import warnings + + hdfstore_path = output_path.replace(".h5", ".hdfstore.h5") + + print(f"\nSaving HDFStore to {hdfstore_path}...") + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=pd.errors.PerformanceWarning, + message=".*PyTables will pickle object types.*", + ) + with pd.HDFStore(hdfstore_path, mode="w") as store: + for entity_name, df in entity_dfs.items(): + # Convert object columns to string for HDFStore compatibility + for col in df.columns: + if df[col].dtype == object: + df[col] = df[col].astype(str) + store.put(entity_name, df, format="table") + + store.put("_variable_metadata", manifest_df, format="table") + store.put( + "_time_period", + pd.Series([time_period]), + format="table", + ) + + # Print summary + for entity_name, df in entity_dfs.items(): + print(f" {entity_name}: {len(df):,} rows, {len(df.columns)} cols") + print(f" manifest: {len(manifest_df)} variables") + + print(f"HDFStore saved successfully!") + return hdfstore_path + + def create_sparse_cd_stacked_dataset( w, cds_to_calibrate, @@ -738,6 +888,20 @@ def create_sparse_cd_stacked_dataset( f" Average persons per household: {np.sum(person_weights) / np.sum(weights):.2f}" ) + # --- HDFStore output (entity-level format) --- + # Split the person-level combined_df into per-entity DataFrames and save + # alongside the h5py file. This format is consumed by the API v2 alpha + # and by policyengine-us's extend_single_year_dataset(). + entity_dfs = _split_into_entity_dfs( + combined_df, base_sim.tax_benefit_system, vars_to_save, time_period + ) + manifest_df = _build_uprating_manifest( + vars_to_save, base_sim.tax_benefit_system + ) + hdfstore_path = _save_hdfstore( + entity_dfs, manifest_df, output_path, time_period + ) + return output_path diff --git a/policyengine_us_data/tests/test_format_comparison.py b/policyengine_us_data/tests/test_format_comparison.py new file mode 100644 index 00000000..4741d41e --- /dev/null +++ b/policyengine_us_data/tests/test_format_comparison.py @@ -0,0 +1,285 @@ +""" +Compare h5py (variable-centric) and HDFStore (entity-level) output formats. + +Verifies that both formats produced by stacked_dataset_builder contain +identical data for all variables. + +Usage as pytest: + pytest test_format_comparison.py --h5py-path path/to/STATE.h5 \ + --hdfstore-path path/to/STATE.hdfstore.h5 + +Usage as standalone script: + python -m policyengine_us_data.tests.test_format_comparison \ + --h5py-path path/to/STATE.h5 \ + --hdfstore-path path/to/STATE.hdfstore.h5 +""" + +import argparse +import sys + +import h5py +import numpy as np +import pandas as pd +import pytest + + +def compare_formats(h5py_path: str, hdfstore_path: str) -> dict: + """Compare all variables between h5py and HDFStore formats. + + Returns a dict with keys: passed, failed, skipped, details. + """ + passed = [] + failed = [] + skipped = [] + + with h5py.File(h5py_path, "r") as f: + h5_vars = sorted(f.keys()) + # Get the year from the first variable's subkeys + first_var = h5_vars[0] + year = list(f[first_var].keys())[0] + + with pd.HDFStore(hdfstore_path, "r") as store: + # Load all entity DataFrames + store_keys = [k for k in store.keys() if not k.startswith("/_")] + entity_dfs = {k: store[k] for k in store_keys} + + # Load manifest + manifest = None + if "/_variable_metadata" in store.keys(): + manifest = store["/_variable_metadata"] + + for var in h5_vars: + h5_values = f[var][year][:] + + # Find which entity DataFrame contains this variable + found = False + for entity_key, df in entity_dfs.items(): + entity_name = entity_key.lstrip("/") + if var in df.columns: + hdf_values = df[var].values + + # For person-level variables, arrays should be + # same length and directly comparable (both are + # ordered by row index from combined_df). + # For group entities, the h5py array is at person + # level while HDFStore is deduplicated. We need + # to handle this difference. + if entity_name != "person" and len(hdf_values) != len( + h5_values + ): + # h5py stores at person level; HDFStore is + # deduplicated by entity ID. We can't do a + # direct comparison — verify unique values match. + h5_unique = np.unique(h5_values) + hdf_unique = np.unique(hdf_values) + if h5_values.dtype.kind in ("U", "S", "O"): + match = set(h5_unique) == set(hdf_unique) + else: + match = np.allclose( + np.sort(h5_unique.astype(float)), + np.sort(hdf_unique.astype(float)), + rtol=1e-5, + equal_nan=True, + ) + if match: + passed.append(var) + else: + failed.append( + ( + var, + f"unique values differ " + f"(h5py: {len(h5_unique)}, " + f"hdfstore: {len(hdf_unique)})", + ) + ) + else: + # Same length — direct comparison + if h5_values.dtype.kind in ("U", "S", "O"): + # String comparison + h5_str = np.array( + [ + ( + x.decode() + if isinstance(x, bytes) + else str(x) + ) + for x in h5_values + ] + ) + hdf_str = np.array( + [str(x) for x in hdf_values] + ) + if np.array_equal(h5_str, hdf_str): + passed.append(var) + else: + mismatches = np.sum(h5_str != hdf_str) + failed.append( + ( + var, + f"{mismatches} string mismatches", + ) + ) + else: + # Numeric comparison + h5_float = h5_values.astype(float) + hdf_float = hdf_values.astype(float) + if np.allclose( + h5_float, + hdf_float, + rtol=1e-5, + equal_nan=True, + ): + passed.append(var) + else: + diff = np.abs(h5_float - hdf_float) + max_diff = np.max(diff) + n_diff = np.sum( + ~np.isclose( + h5_float, + hdf_float, + rtol=1e-5, + equal_nan=True, + ) + ) + failed.append( + ( + var, + f"{n_diff} values differ, " + f"max diff={max_diff:.6f}", + ) + ) + found = True + break + + if not found: + skipped.append(var) + + return { + "passed": passed, + "failed": failed, + "skipped": skipped, + "total_h5py_vars": len(h5_vars), + } + + +def pytest_addoption(parser): + parser.addoption("--h5py-path", action="store", default=None) + parser.addoption("--hdfstore-path", action="store", default=None) + + +@pytest.fixture +def h5py_path(request): + path = request.config.getoption("--h5py-path") + if path is None: + pytest.skip("--h5py-path not provided") + return path + + +@pytest.fixture +def hdfstore_path(request): + path = request.config.getoption("--hdfstore-path") + if path is None: + pytest.skip("--hdfstore-path not provided") + return path + + +def test_formats_match(h5py_path, hdfstore_path): + """Verify h5py and HDFStore formats contain identical data.""" + result = compare_formats(h5py_path, hdfstore_path) + + print(f"\n{'='*60}") + print(f"Format Comparison Results") + print(f"{'='*60}") + print(f"Total h5py variables: {result['total_h5py_vars']}") + print(f"Passed: {len(result['passed'])}") + print(f"Failed: {len(result['failed'])}") + print(f"Skipped (not in HDFStore): {len(result['skipped'])}") + + if result["failed"]: + print(f"\nFailed variables:") + for var, reason in result["failed"]: + print(f" {var}: {reason}") + + if result["skipped"]: + print(f"\nSkipped variables (not found in HDFStore):") + for var in result["skipped"]: + print(f" {var}") + + assert len(result["failed"]) == 0, ( + f"{len(result['failed'])} variables have mismatched values" + ) + assert len(result["skipped"]) == 0, ( + f"{len(result['skipped'])} variables missing from HDFStore" + ) + + +def test_manifest_present(hdfstore_path): + """Verify the HDFStore contains a variable metadata manifest.""" + with pd.HDFStore(hdfstore_path, "r") as store: + assert "/_variable_metadata" in store.keys(), ( + "Missing _variable_metadata table" + ) + manifest = store["/_variable_metadata"] + assert "variable" in manifest.columns + assert "entity" in manifest.columns + assert "uprating" in manifest.columns + assert len(manifest) > 0, "Manifest is empty" + print(f"\nManifest has {len(manifest)} variables") + print(f"Entities: {manifest['entity'].unique().tolist()}") + n_uprated = (manifest["uprating"] != "").sum() + print(f"Variables with uprating: {n_uprated}") + + +def test_all_entities_present(hdfstore_path): + """Verify the HDFStore contains all expected entity tables.""" + expected = {"person", "household", "tax_unit", "spm_unit", "family", "marital_unit"} + with pd.HDFStore(hdfstore_path, "r") as store: + actual = {k.lstrip("/") for k in store.keys() if not k.startswith("/_")} + missing = expected - actual + assert not missing, f"Missing entity tables: {missing}" + for entity in expected: + df = store[f"/{entity}"] + assert len(df) > 0, f"Entity {entity} has 0 rows" + assert f"{entity}_id" in df.columns, ( + f"Entity {entity} missing {entity}_id column" + ) + print(f" {entity}: {len(df):,} rows, {len(df.columns)} cols") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Compare h5py and HDFStore dataset formats" + ) + parser.add_argument( + "--h5py-path", required=True, help="Path to h5py format file" + ) + parser.add_argument( + "--hdfstore-path", required=True, help="Path to HDFStore format file" + ) + args = parser.parse_args() + + result = compare_formats(args.h5py_path, args.hdfstore_path) + + print(f"\n{'='*60}") + print(f"Format Comparison Results") + print(f"{'='*60}") + print(f"Total h5py variables: {result['total_h5py_vars']}") + print(f"Passed: {len(result['passed'])}") + print(f"Failed: {len(result['failed'])}") + print(f"Skipped (not in HDFStore): {len(result['skipped'])}") + + if result["failed"]: + print(f"\nFailed variables:") + for var, reason in result["failed"]: + print(f" {var}: {reason}") + + if result["skipped"]: + print(f"\nSkipped variables (not found in HDFStore):") + for var in result["skipped"]: + print(f" {var}") + + if result["failed"] or result["skipped"]: + sys.exit(1) + else: + print("\nAll variables match!") + sys.exit(0) From 16eeacf7b0e26da063df4e30c317a66357dc5b01 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 5 Mar 2026 17:47:22 +0100 Subject: [PATCH 2/8] Rewrite format comparison as one-off validation script Replaces the two-file-input test with a self-contained roundtrip script that takes only an h5py file path, generates an HDFStore using inlined splitting logic, then compares both formats. Handles entity-level h5py files and yearly/ETERNITY/monthly period keys. Co-Authored-By: Claude Opus 4.6 --- .../tests/test_format_comparison.py | 403 ++++++++++++++---- 1 file changed, 316 insertions(+), 87 deletions(-) diff --git a/policyengine_us_data/tests/test_format_comparison.py b/policyengine_us_data/tests/test_format_comparison.py index 4741d41e..3bc2005e 100644 --- a/policyengine_us_data/tests/test_format_comparison.py +++ b/policyengine_us_data/tests/test_format_comparison.py @@ -1,21 +1,24 @@ """ -Compare h5py (variable-centric) and HDFStore (entity-level) output formats. +ONE-OFF VALIDATION SCRIPT -Verifies that both formats produced by stacked_dataset_builder contain -identical data for all variables. +This is a one-off script used to verify that the h5py-to-HDFStore +conversion logic is correct. It reads an existing h5py dataset file, +converts it to entity-level Pandas HDFStore using the same splitting/dedup +logic as stacked_dataset_builder, then compares all variables to verify +the conversion is lossless. -Usage as pytest: - pytest test_format_comparison.py --h5py-path path/to/STATE.h5 \ - --hdfstore-path path/to/STATE.hdfstore.h5 +This script is NOT part of the regular test suite and is not intended to +be run in CI. It exists to validate the HDFStore serialization logic +during development. -Usage as standalone script: - python -m policyengine_us_data.tests.test_format_comparison \ - --h5py-path path/to/STATE.h5 \ - --hdfstore-path path/to/STATE.hdfstore.h5 +Usage (run directly to avoid policyengine_us_data __init__ imports): + python policyengine_us_data/tests/test_format_comparison.py \ + --h5py-path path/to/STATE.h5 """ import argparse import sys +import warnings import h5py import numpy as np @@ -23,10 +26,219 @@ import pytest +ENTITIES = [ + "person", + "household", + "tax_unit", + "spm_unit", + "family", + "marital_unit", +] + + +def _load_system(): + """Load the policyengine-us tax-benefit system.""" + from policyengine_us import system as us_system + + return us_system.system + + +# --------------------------------------------------------------------------- +# h5py -> HDFStore conversion (self-contained reproduction of the builder +# logic so we don't need to import stacked_dataset_builder and its heavy deps) +# --------------------------------------------------------------------------- + + +def _read_h5py_arrays(h5py_path: str): + """Read all arrays from an h5py variable-centric file. + + The h5py format stores ``variable / period -> array``. Periods can be + yearly (``"2024"``), monthly (``"2024-01"``), or ``"ETERNITY"``. + + Some h5py files are fully person-level (all arrays have the same length). + Others are already entity-level: group-entity variables have fewer rows + than person-level variables. + + Returns ``(arrays, time_period, h5_vars)`` where arrays is a dict of + ``{variable_name: numpy_array}``. + """ + with h5py.File(h5py_path, "r") as f: + h5_vars = sorted(f.keys()) + + # Determine the canonical year from the first variable that has one + year = None + for var in h5_vars: + subkeys = list(f[var].keys()) + for sk in subkeys: + if sk.isdigit() and len(sk) == 4: + year = sk + break + if year is not None: + break + if year is None: + raise ValueError("Could not determine year from h5py file") + + time_period = int(year) + arrays = {} + + for var in h5_vars: + subkeys = list(f[var].keys()) + if year in subkeys: + period_key = year + elif "ETERNITY" in subkeys: + period_key = "ETERNITY" + else: + period_key = subkeys[0] + + arr = f[var][period_key][:] + if arr.dtype.kind in ("S", "O"): + arr = np.array( + [ + x.decode() if isinstance(x, bytes) else str(x) + for x in arr + ] + ) + arrays[var] = arr + + return arrays, time_period, h5_vars + + +def _split_into_entity_dfs(arrays, system, vars_to_save): + """Build entity-level DataFrames from a dict of variable arrays. + + ``arrays`` maps variable names to numpy arrays. Arrays may already be + at entity-level (different lengths for different entities) or all at + person-level. We group variables by entity, then build one DataFrame + per entity using arrays of matching length. + """ + entity_cols = {e: [] for e in ENTITIES} + + for var in sorted(vars_to_save): + if var not in arrays: + continue + if var in system.variables: + entity_key = system.variables[var].entity.key + entity_cols[entity_key].append(var) + else: + entity_cols["household"].append(var) + + # Person DataFrame: person vars + entity membership IDs + person_vars = entity_cols["person"][:] + if "person_id" not in person_vars and "person_id" in arrays: + person_vars.insert(0, "person_id") + for entity in ENTITIES[1:]: + ref_col = f"person_{entity}_id" + if ref_col in arrays: + person_vars.append(ref_col) + + person_df = pd.DataFrame({v: arrays[v] for v in person_vars if v in arrays}) + entity_dfs = {"person": person_df} + + # Group entity DataFrames + for entity in ENTITIES[1:]: + id_col = f"{entity}_id" + vars_for_entity = entity_cols[entity][:] + if id_col not in vars_for_entity and id_col in arrays: + vars_for_entity.insert(0, id_col) + + if not vars_for_entity: + continue + + # Check if the arrays are already at entity level (shorter than + # person) or at person level (same length as person_id) + n_persons = len(arrays.get("person_id", [])) + sample_len = len(arrays[vars_for_entity[0]]) + + df_data = {v: arrays[v] for v in vars_for_entity if v in arrays} + df = pd.DataFrame(df_data) + + if sample_len == n_persons and id_col in df.columns: + # Person-level: need to deduplicate by entity ID + df = df.drop_duplicates(subset=[id_col]).reset_index(drop=True) + + entity_dfs[entity] = df + + return entity_dfs + + +def _build_uprating_manifest(vars_to_save, system): + """Build manifest of variable metadata.""" + records = [] + for var in sorted(vars_to_save): + entity = ( + system.variables[var].entity.key + if var in system.variables + else "unknown" + ) + uprating = "" + if var in system.variables: + uprating = getattr(system.variables[var], "uprating", None) or "" + records.append( + {"variable": var, "entity": entity, "uprating": uprating} + ) + return pd.DataFrame(records) + + +def _save_hdfstore(entity_dfs, manifest_df, hdfstore_path, time_period): + """Save entity DataFrames and manifest to a Pandas HDFStore file.""" + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=pd.errors.PerformanceWarning, + message=".*PyTables will pickle object types.*", + ) + with pd.HDFStore(hdfstore_path, mode="w") as store: + for entity_name, df in entity_dfs.items(): + # Deduplicate column names (can happen if a var appears + # in multiple entity buckets) + df = df.loc[:, ~df.columns.duplicated()] + for col in df.columns: + if df[col].dtype == object: + df[col] = df[col].astype(str) + store.put(entity_name, df, format="table") + store.put("_variable_metadata", manifest_df, format="table") + store.put( + "_time_period", pd.Series([time_period]), format="table" + ) + return hdfstore_path + + +# --------------------------------------------------------------------------- +# Main conversion + comparison logic +# --------------------------------------------------------------------------- + + +def h5py_to_hdfstore(h5py_path: str, hdfstore_path: str) -> dict: + """Convert an h5py variable-centric file to entity-level HDFStore. + + Returns a summary dict with entity row counts. + """ + print("Loading policyengine-us system (this takes a minute)...") + system = _load_system() + + print("Reading h5py file...") + arrays, time_period, h5_vars = _read_h5py_arrays(h5py_path) + n_persons = len(arrays.get("person_id", [])) + print(f" {len(h5_vars)} variables, {n_persons:,} persons, year={time_period}") + + print("Splitting into entity DataFrames...") + entity_dfs = _split_into_entity_dfs(arrays, system, h5_vars) + manifest_df = _build_uprating_manifest(h5_vars, system) + + print(f"Saving HDFStore to {hdfstore_path}...") + _save_hdfstore(entity_dfs, manifest_df, hdfstore_path, time_period) + + summary = {} + for entity_name, df in entity_dfs.items(): + summary[entity_name] = {"rows": len(df), "cols": len(df.columns)} + summary["manifest_vars"] = len(manifest_df) + return summary + + def compare_formats(h5py_path: str, hdfstore_path: str) -> dict: - """Compare all variables between h5py and HDFStore formats. + """Compare all variables between h5py and generated HDFStore. - Returns a dict with keys: passed, failed, skipped, details. + Returns a dict with keys: passed, failed, skipped. """ passed = [] failed = [] @@ -34,46 +246,52 @@ def compare_formats(h5py_path: str, hdfstore_path: str) -> dict: with h5py.File(h5py_path, "r") as f: h5_vars = sorted(f.keys()) - # Get the year from the first variable's subkeys - first_var = h5_vars[0] - year = list(f[first_var].keys())[0] + + # Determine the year + year = None + for var in h5_vars: + for sk in f[var].keys(): + if sk.isdigit() and len(sk) == 4: + year = sk + break + if year is not None: + break with pd.HDFStore(hdfstore_path, "r") as store: - # Load all entity DataFrames store_keys = [k for k in store.keys() if not k.startswith("/_")] entity_dfs = {k: store[k] for k in store_keys} - # Load manifest - manifest = None - if "/_variable_metadata" in store.keys(): - manifest = store["/_variable_metadata"] - for var in h5_vars: - h5_values = f[var][year][:] + subkeys = list(f[var].keys()) + if year in subkeys: + period_key = year + elif "ETERNITY" in subkeys: + period_key = "ETERNITY" + else: + period_key = subkeys[0] + + h5_values = f[var][period_key][:] - # Find which entity DataFrame contains this variable found = False for entity_key, df in entity_dfs.items(): entity_name = entity_key.lstrip("/") if var in df.columns: hdf_values = df[var].values - # For person-level variables, arrays should be - # same length and directly comparable (both are - # ordered by row index from combined_df). - # For group entities, the h5py array is at person - # level while HDFStore is deduplicated. We need - # to handle this difference. + # For group entities, h5py is person-level while + # HDFStore is deduplicated by entity ID. if entity_name != "person" and len(hdf_values) != len( h5_values ): - # h5py stores at person level; HDFStore is - # deduplicated by entity ID. We can't do a - # direct comparison — verify unique values match. h5_unique = np.unique(h5_values) hdf_unique = np.unique(hdf_values) if h5_values.dtype.kind in ("U", "S", "O"): - match = set(h5_unique) == set(hdf_unique) + match = set( + x.decode() + if isinstance(x, bytes) + else str(x) + for x in h5_unique + ) == set(str(x) for x in hdf_unique) else: match = np.allclose( np.sort(h5_unique.astype(float)), @@ -95,7 +313,6 @@ def compare_formats(h5py_path: str, hdfstore_path: str) -> dict: else: # Same length — direct comparison if h5_values.dtype.kind in ("U", "S", "O"): - # String comparison h5_str = np.array( [ ( @@ -120,7 +337,6 @@ def compare_formats(h5py_path: str, hdfstore_path: str) -> dict: ) ) else: - # Numeric comparison h5_float = h5_values.astype(float) hdf_float = hdf_values.astype(float) if np.allclose( @@ -162,9 +378,32 @@ def compare_formats(h5py_path: str, hdfstore_path: str) -> dict: } +def print_results(result): + """Print comparison results to stdout.""" + print(f"\n{'='*60}") + print("Format Comparison Results") + print(f"{'='*60}") + print(f"Total h5py variables: {result['total_h5py_vars']}") + print(f"Passed: {len(result['passed'])}") + print(f"Failed: {len(result['failed'])}") + print(f"Skipped (not in HDFStore): {len(result['skipped'])}") + + if result["failed"]: + print("\nFailed variables:") + for var, reason in result["failed"]: + print(f" {var}: {reason}") + + if result["skipped"]: + print("\nSkipped variables (not found in HDFStore):") + for var in result["skipped"]: + print(f" {var}") + + +# --- pytest interface --- + + def pytest_addoption(parser): parser.addoption("--h5py-path", action="store", default=None) - parser.addoption("--hdfstore-path", action="store", default=None) @pytest.fixture @@ -175,35 +414,17 @@ def h5py_path(request): return path -@pytest.fixture -def hdfstore_path(request): - path = request.config.getoption("--hdfstore-path") - if path is None: - pytest.skip("--hdfstore-path not provided") - return path +def test_roundtrip(h5py_path, tmp_path): + """Convert h5py -> HDFStore -> compare all variables.""" + hdfstore_path = str(tmp_path / "test_output.hdfstore.h5") + summary = h5py_to_hdfstore(h5py_path, hdfstore_path) + for entity, info in summary.items(): + if isinstance(info, dict): + print(f" {entity}: {info['rows']:,} rows, {info['cols']} cols") -def test_formats_match(h5py_path, hdfstore_path): - """Verify h5py and HDFStore formats contain identical data.""" result = compare_formats(h5py_path, hdfstore_path) - - print(f"\n{'='*60}") - print(f"Format Comparison Results") - print(f"{'='*60}") - print(f"Total h5py variables: {result['total_h5py_vars']}") - print(f"Passed: {len(result['passed'])}") - print(f"Failed: {len(result['failed'])}") - print(f"Skipped (not in HDFStore): {len(result['skipped'])}") - - if result["failed"]: - print(f"\nFailed variables:") - for var, reason in result["failed"]: - print(f" {var}: {reason}") - - if result["skipped"]: - print(f"\nSkipped variables (not found in HDFStore):") - for var in result["skipped"]: - print(f" {var}") + print_results(result) assert len(result["failed"]) == 0, ( f"{len(result['failed'])} variables have mismatched values" @@ -213,8 +434,11 @@ def test_formats_match(h5py_path, hdfstore_path): ) -def test_manifest_present(hdfstore_path): - """Verify the HDFStore contains a variable metadata manifest.""" +def test_manifest(h5py_path, tmp_path): + """Verify the generated HDFStore contains a valid manifest.""" + hdfstore_path = str(tmp_path / "test_output.hdfstore.h5") + h5py_to_hdfstore(h5py_path, hdfstore_path) + with pd.HDFStore(hdfstore_path, "r") as store: assert "/_variable_metadata" in store.keys(), ( "Missing _variable_metadata table" @@ -230,11 +454,16 @@ def test_manifest_present(hdfstore_path): print(f"Variables with uprating: {n_uprated}") -def test_all_entities_present(hdfstore_path): - """Verify the HDFStore contains all expected entity tables.""" - expected = {"person", "household", "tax_unit", "spm_unit", "family", "marital_unit"} +def test_all_entities(h5py_path, tmp_path): + """Verify the generated HDFStore contains all expected entity tables.""" + hdfstore_path = str(tmp_path / "test_output.hdfstore.h5") + h5py_to_hdfstore(h5py_path, hdfstore_path) + + expected = set(ENTITIES) with pd.HDFStore(hdfstore_path, "r") as store: - actual = {k.lstrip("/") for k in store.keys() if not k.startswith("/_")} + actual = { + k.lstrip("/") for k in store.keys() if not k.startswith("/_") + } missing = expected - actual assert not missing, f"Missing entity tables: {missing}" for entity in expected: @@ -246,37 +475,37 @@ def test_all_entities_present(hdfstore_path): print(f" {entity}: {len(df):,} rows, {len(df.columns)} cols") +# --- CLI interface --- + + if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Compare h5py and HDFStore dataset formats" + description="Convert h5py dataset to HDFStore and verify roundtrip" ) parser.add_argument( "--h5py-path", required=True, help="Path to h5py format file" ) parser.add_argument( - "--hdfstore-path", required=True, help="Path to HDFStore format file" + "--output-path", + default=None, + help="Path for generated HDFStore (default: alongside input file)", ) args = parser.parse_args() - result = compare_formats(args.h5py_path, args.hdfstore_path) - - print(f"\n{'='*60}") - print(f"Format Comparison Results") - print(f"{'='*60}") - print(f"Total h5py variables: {result['total_h5py_vars']}") - print(f"Passed: {len(result['passed'])}") - print(f"Failed: {len(result['failed'])}") - print(f"Skipped (not in HDFStore): {len(result['skipped'])}") + if args.output_path: + hdfstore_path = args.output_path + else: + hdfstore_path = args.h5py_path.replace(".h5", ".hdfstore.h5") - if result["failed"]: - print(f"\nFailed variables:") - for var, reason in result["failed"]: - print(f" {var}: {reason}") + print(f"Converting {args.h5py_path} -> {hdfstore_path}...") + summary = h5py_to_hdfstore(args.h5py_path, hdfstore_path) + for entity, info in summary.items(): + if isinstance(info, dict): + print(f" {entity}: {info['rows']:,} rows, {info['cols']} cols") - if result["skipped"]: - print(f"\nSkipped variables (not found in HDFStore):") - for var in result["skipped"]: - print(f" {var}") + print("\nComparing formats...") + result = compare_formats(args.h5py_path, hdfstore_path) + print_results(result) if result["failed"] or result["skipped"]: sys.exit(1) From 82b88fd1dae23bdf65ff04f2090914f23b179977 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 5 Mar 2026 17:56:10 +0100 Subject: [PATCH 3/8] style: Run black formatter on changed files Co-Authored-By: Claude Opus 4.6 --- .../publish_local_area.py | 4 +- .../stacked_dataset_builder.py | 3 +- .../tests/test_format_comparison.py | 45 ++++++++++--------- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py b/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py index 42bfd1b7..9f6cbbac 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py +++ b/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py @@ -445,9 +445,7 @@ def build_and_upload_cities( ) # Upload HDFStore file if it exists - hdfstore_path = str(output_path).replace( - ".h5", ".hdfstore.h5" - ) + hdfstore_path = str(output_path).replace(".h5", ".hdfstore.h5") if os.path.exists(hdfstore_path): print("Uploading NYC.hdfstore.h5 to GCP...") upload_local_area_file( diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/stacked_dataset_builder.py b/policyengine_us_data/datasets/cps/local_area_calibration/stacked_dataset_builder.py index c4d449ce..8257a914 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/stacked_dataset_builder.py +++ b/policyengine_us_data/datasets/cps/local_area_calibration/stacked_dataset_builder.py @@ -140,8 +140,7 @@ def _split_into_entity_dfs(combined_df, system, vars_to_save, time_period): df = combined_df[src_cols].copy() # Strip period suffix df.columns = [ - c[: -len(suffix)] if c.endswith(suffix) else c - for c in df.columns + c[: -len(suffix)] if c.endswith(suffix) else c for c in df.columns ] # Rename person_X_id -> X_id if needed if src_id == person_ref and person_ref != id_col: diff --git a/policyengine_us_data/tests/test_format_comparison.py b/policyengine_us_data/tests/test_format_comparison.py index 3bc2005e..e69f383c 100644 --- a/policyengine_us_data/tests/test_format_comparison.py +++ b/policyengine_us_data/tests/test_format_comparison.py @@ -25,7 +25,6 @@ import pandas as pd import pytest - ENTITIES = [ "person", "household", @@ -131,7 +130,9 @@ def _split_into_entity_dfs(arrays, system, vars_to_save): if ref_col in arrays: person_vars.append(ref_col) - person_df = pd.DataFrame({v: arrays[v] for v in person_vars if v in arrays}) + person_df = pd.DataFrame( + {v: arrays[v] for v in person_vars if v in arrays} + ) entity_dfs = {"person": person_df} # Group entity DataFrames @@ -197,9 +198,7 @@ def _save_hdfstore(entity_dfs, manifest_df, hdfstore_path, time_period): df[col] = df[col].astype(str) store.put(entity_name, df, format="table") store.put("_variable_metadata", manifest_df, format="table") - store.put( - "_time_period", pd.Series([time_period]), format="table" - ) + store.put("_time_period", pd.Series([time_period]), format="table") return hdfstore_path @@ -219,7 +218,9 @@ def h5py_to_hdfstore(h5py_path: str, hdfstore_path: str) -> dict: print("Reading h5py file...") arrays, time_period, h5_vars = _read_h5py_arrays(h5py_path) n_persons = len(arrays.get("person_id", [])) - print(f" {len(h5_vars)} variables, {n_persons:,} persons, year={time_period}") + print( + f" {len(h5_vars)} variables, {n_persons:,} persons, year={time_period}" + ) print("Splitting into entity DataFrames...") entity_dfs = _split_into_entity_dfs(arrays, system, h5_vars) @@ -287,9 +288,11 @@ def compare_formats(h5py_path: str, hdfstore_path: str) -> dict: hdf_unique = np.unique(hdf_values) if h5_values.dtype.kind in ("U", "S", "O"): match = set( - x.decode() - if isinstance(x, bytes) - else str(x) + ( + x.decode() + if isinstance(x, bytes) + else str(x) + ) for x in h5_unique ) == set(str(x) for x in hdf_unique) else: @@ -426,12 +429,12 @@ def test_roundtrip(h5py_path, tmp_path): result = compare_formats(h5py_path, hdfstore_path) print_results(result) - assert len(result["failed"]) == 0, ( - f"{len(result['failed'])} variables have mismatched values" - ) - assert len(result["skipped"]) == 0, ( - f"{len(result['skipped'])} variables missing from HDFStore" - ) + assert ( + len(result["failed"]) == 0 + ), f"{len(result['failed'])} variables have mismatched values" + assert ( + len(result["skipped"]) == 0 + ), f"{len(result['skipped'])} variables missing from HDFStore" def test_manifest(h5py_path, tmp_path): @@ -440,9 +443,9 @@ def test_manifest(h5py_path, tmp_path): h5py_to_hdfstore(h5py_path, hdfstore_path) with pd.HDFStore(hdfstore_path, "r") as store: - assert "/_variable_metadata" in store.keys(), ( - "Missing _variable_metadata table" - ) + assert ( + "/_variable_metadata" in store.keys() + ), "Missing _variable_metadata table" manifest = store["/_variable_metadata"] assert "variable" in manifest.columns assert "entity" in manifest.columns @@ -469,9 +472,9 @@ def test_all_entities(h5py_path, tmp_path): for entity in expected: df = store[f"/{entity}"] assert len(df) > 0, f"Entity {entity} has 0 rows" - assert f"{entity}_id" in df.columns, ( - f"Entity {entity} missing {entity}_id column" - ) + assert ( + f"{entity}_id" in df.columns + ), f"Entity {entity} missing {entity}_id column" print(f" {entity}: {len(df):,} rows, {len(df.columns)} cols") From 1c123050d81c3c564f47ca37abe142d4bfd57f7e Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Thu, 12 Mar 2026 16:56:09 -0400 Subject: [PATCH 4/8] Port HDFStore helper functions to new build_h5() structure Adapts _split_into_entity_dfs, _build_uprating_manifest, and _save_hdfstore to work with the new data dict (var -> {period -> array}) instead of the old combined_df DataFrame. Adds the integration call at the end of build_h5() so HDFStore files are generated alongside h5py. Co-Authored-By: Claude Opus 4.6 --- .../calibration/publish_local_area.py | 129 +++++++++++++++++- 1 file changed, 128 insertions(+), 1 deletion(-) diff --git a/policyengine_us_data/calibration/publish_local_area.py b/policyengine_us_data/calibration/publish_local_area.py index cbacfd08..6816b723 100644 --- a/policyengine_us_data/calibration/publish_local_area.py +++ b/policyengine_us_data/calibration/publish_local_area.py @@ -9,9 +9,11 @@ """ import os +import warnings import numpy as np +import pandas as pd from pathlib import Path -from typing import List +from typing import Dict, List from policyengine_us import Microsimulation from policyengine_us_data.utils.huggingface import download_calibration_inputs @@ -106,6 +108,126 @@ def record_completed_city(city_name: str): f.write(f"{city_name}\n") +def _split_data_into_entity_dfs( + data: Dict[str, dict], + system, + time_period: int, +) -> Dict[str, pd.DataFrame]: + """Split the data dict into per-entity DataFrames. + + Groups variables by entity, builds one DataFrame per entity. + Group entities are deduplicated by their ID column. + """ + ENTITIES = [ + "person", + "household", + "tax_unit", + "spm_unit", + "family", + "marital_unit", + ] + entity_vars: Dict[str, list] = {e: [] for e in ENTITIES} + + for var_name in sorted(data.keys()): + if var_name in system.variables: + ek = system.variables[var_name].entity.key + if ek in entity_vars: + entity_vars[ek].append(var_name) + else: + entity_vars["household"].append(var_name) + + entity_dfs: Dict[str, pd.DataFrame] = {} + for entity in ENTITIES: + id_col = f"{entity}_id" + cols = {} + for var_name in entity_vars[entity]: + periods = data[var_name] + tp_key = time_period if time_period in periods else str(time_period) + if tp_key not in periods: + continue + arr = periods[tp_key] + if hasattr(arr, "dtype") and arr.dtype.kind == "S": + arr = np.char.decode(arr, "utf-8") + cols[var_name] = arr + + if entity == "person": + for ref_entity in ENTITIES[1:]: + ref_col = f"person_{ref_entity}_id" + if ref_col in data: + periods = data[ref_col] + tp_key = time_period if time_period in periods else str(time_period) + if tp_key in periods: + cols[ref_col] = periods[tp_key] + + if not cols: + continue + + df = pd.DataFrame(cols) + if entity != "person" and id_col in df.columns: + df = df.drop_duplicates(subset=[id_col]).reset_index(drop=True) + entity_dfs[entity] = df + + return entity_dfs + + +def _build_uprating_manifest( + data: Dict[str, dict], + system, +) -> pd.DataFrame: + """Build manifest of variable metadata for embedding in HDFStore.""" + records = [] + for var_name in sorted(data.keys()): + entity = ( + system.variables[var_name].entity.key + if var_name in system.variables + else "unknown" + ) + uprating = "" + if var_name in system.variables: + uprating = getattr(system.variables[var_name], "uprating", None) or "" + records.append({"variable": var_name, "entity": entity, "uprating": uprating}) + return pd.DataFrame(records) + + +def _save_hdfstore( + entity_dfs: Dict[str, pd.DataFrame], + manifest_df: pd.DataFrame, + output_path: str, + time_period: int, +) -> str: + """Save entity DataFrames and manifest to a Pandas HDFStore file.""" + hdfstore_path = str(output_path).replace(".h5", ".hdfstore.h5") + + print(f"\nSaving HDFStore to {hdfstore_path}...") + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=pd.errors.PerformanceWarning, + message=".*PyTables will pickle object types.*", + ) + with pd.HDFStore(hdfstore_path, mode="w") as store: + for entity_name, df in entity_dfs.items(): + for col in df.columns: + if df[col].dtype == object: + df[col] = df[col].astype(str) + store.put(entity_name, df, format="table") + + store.put("_variable_metadata", manifest_df, format="table") + store.put( + "_time_period", + pd.Series([time_period]), + format="table", + ) + + for entity_name, df in entity_dfs.items(): + print(f" {entity_name}: {len(df):,} rows, {len(df.columns)} cols") + print(f" manifest: {len(manifest_df)} variables") + print("HDFStore saved successfully!") + + return hdfstore_path + + def build_h5( weights: np.ndarray, geography, @@ -564,6 +686,11 @@ def build_h5( pw = f["person_weight"][tp][:] print(f"Total population (person weights): {pw.sum():,.0f}") + # === HDFStore output (entity-level format) === + entity_dfs = _split_data_into_entity_dfs(data, sim.tax_benefit_system, time_period) + manifest_df = _build_uprating_manifest(data, sim.tax_benefit_system) + _save_hdfstore(entity_dfs, manifest_df, str(output_path), time_period) + return output_path From 11bf0939546709a6d20c4c159ab6d451061589bf Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 16 Mar 2026 20:00:25 +0100 Subject: [PATCH 5/8] Extract HDFStore logic into shared utils/hdfstore.py module Moves split_data_into_entity_dfs, build_uprating_manifest, and save_hdfstore out of publish_local_area.py into a standalone utility module. Updates test_format_comparison.py to import the production functions instead of reimplementing them, bridging the h5py flat-array format into the nested {var: {period: array}} structure expected by the shared module. Co-Authored-By: Claude Opus 4.6 --- .../calibration/publish_local_area.py | 132 +------------ .../tests/test_format_comparison.py | 149 +++------------ policyengine_us_data/utils/hdfstore.py | 175 ++++++++++++++++++ 3 files changed, 210 insertions(+), 246 deletions(-) create mode 100644 policyengine_us_data/utils/hdfstore.py diff --git a/policyengine_us_data/calibration/publish_local_area.py b/policyengine_us_data/calibration/publish_local_area.py index 6816b723..ad3ccb1a 100644 --- a/policyengine_us_data/calibration/publish_local_area.py +++ b/policyengine_us_data/calibration/publish_local_area.py @@ -9,13 +9,17 @@ """ import os -import warnings import numpy as np import pandas as pd from pathlib import Path from typing import Dict, List from policyengine_us import Microsimulation +from policyengine_us_data.utils.hdfstore import ( + split_data_into_entity_dfs, + build_uprating_manifest, + save_hdfstore, +) from policyengine_us_data.utils.huggingface import download_calibration_inputs from policyengine_us_data.utils.data_upload import ( upload_local_area_file, @@ -108,126 +112,6 @@ def record_completed_city(city_name: str): f.write(f"{city_name}\n") -def _split_data_into_entity_dfs( - data: Dict[str, dict], - system, - time_period: int, -) -> Dict[str, pd.DataFrame]: - """Split the data dict into per-entity DataFrames. - - Groups variables by entity, builds one DataFrame per entity. - Group entities are deduplicated by their ID column. - """ - ENTITIES = [ - "person", - "household", - "tax_unit", - "spm_unit", - "family", - "marital_unit", - ] - entity_vars: Dict[str, list] = {e: [] for e in ENTITIES} - - for var_name in sorted(data.keys()): - if var_name in system.variables: - ek = system.variables[var_name].entity.key - if ek in entity_vars: - entity_vars[ek].append(var_name) - else: - entity_vars["household"].append(var_name) - - entity_dfs: Dict[str, pd.DataFrame] = {} - for entity in ENTITIES: - id_col = f"{entity}_id" - cols = {} - for var_name in entity_vars[entity]: - periods = data[var_name] - tp_key = time_period if time_period in periods else str(time_period) - if tp_key not in periods: - continue - arr = periods[tp_key] - if hasattr(arr, "dtype") and arr.dtype.kind == "S": - arr = np.char.decode(arr, "utf-8") - cols[var_name] = arr - - if entity == "person": - for ref_entity in ENTITIES[1:]: - ref_col = f"person_{ref_entity}_id" - if ref_col in data: - periods = data[ref_col] - tp_key = time_period if time_period in periods else str(time_period) - if tp_key in periods: - cols[ref_col] = periods[tp_key] - - if not cols: - continue - - df = pd.DataFrame(cols) - if entity != "person" and id_col in df.columns: - df = df.drop_duplicates(subset=[id_col]).reset_index(drop=True) - entity_dfs[entity] = df - - return entity_dfs - - -def _build_uprating_manifest( - data: Dict[str, dict], - system, -) -> pd.DataFrame: - """Build manifest of variable metadata for embedding in HDFStore.""" - records = [] - for var_name in sorted(data.keys()): - entity = ( - system.variables[var_name].entity.key - if var_name in system.variables - else "unknown" - ) - uprating = "" - if var_name in system.variables: - uprating = getattr(system.variables[var_name], "uprating", None) or "" - records.append({"variable": var_name, "entity": entity, "uprating": uprating}) - return pd.DataFrame(records) - - -def _save_hdfstore( - entity_dfs: Dict[str, pd.DataFrame], - manifest_df: pd.DataFrame, - output_path: str, - time_period: int, -) -> str: - """Save entity DataFrames and manifest to a Pandas HDFStore file.""" - hdfstore_path = str(output_path).replace(".h5", ".hdfstore.h5") - - print(f"\nSaving HDFStore to {hdfstore_path}...") - - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", - category=pd.errors.PerformanceWarning, - message=".*PyTables will pickle object types.*", - ) - with pd.HDFStore(hdfstore_path, mode="w") as store: - for entity_name, df in entity_dfs.items(): - for col in df.columns: - if df[col].dtype == object: - df[col] = df[col].astype(str) - store.put(entity_name, df, format="table") - - store.put("_variable_metadata", manifest_df, format="table") - store.put( - "_time_period", - pd.Series([time_period]), - format="table", - ) - - for entity_name, df in entity_dfs.items(): - print(f" {entity_name}: {len(df):,} rows, {len(df.columns)} cols") - print(f" manifest: {len(manifest_df)} variables") - print("HDFStore saved successfully!") - - return hdfstore_path - - def build_h5( weights: np.ndarray, geography, @@ -687,9 +571,9 @@ def build_h5( print(f"Total population (person weights): {pw.sum():,.0f}") # === HDFStore output (entity-level format) === - entity_dfs = _split_data_into_entity_dfs(data, sim.tax_benefit_system, time_period) - manifest_df = _build_uprating_manifest(data, sim.tax_benefit_system) - _save_hdfstore(entity_dfs, manifest_df, str(output_path), time_period) + entity_dfs = split_data_into_entity_dfs(data, sim.tax_benefit_system, time_period) + manifest_df = build_uprating_manifest(data, sim.tax_benefit_system) + save_hdfstore(entity_dfs, manifest_df, str(output_path), time_period) return output_path diff --git a/policyengine_us_data/tests/test_format_comparison.py b/policyengine_us_data/tests/test_format_comparison.py index 3ddfa3ce..722064b5 100644 --- a/policyengine_us_data/tests/test_format_comparison.py +++ b/policyengine_us_data/tests/test_format_comparison.py @@ -3,9 +3,9 @@ This is a one-off script used to verify that the h5py-to-HDFStore conversion logic is correct. It reads an existing h5py dataset file, -converts it to entity-level Pandas HDFStore using the same splitting/dedup -logic as stacked_dataset_builder, then compares all variables to verify -the conversion is lossless. +converts it to entity-level Pandas HDFStore using the production +splitting/dedup logic, then compares all variables to verify the +conversion is lossless. This script is NOT part of the regular test suite and is not intended to be run in CI. It exists to validate the HDFStore serialization logic @@ -18,21 +18,18 @@ import argparse import sys -import warnings import h5py import numpy as np import pandas as pd import pytest -ENTITIES = [ - "person", - "household", - "tax_unit", - "spm_unit", - "family", - "marital_unit", -] +from policyengine_us_data.utils.hdfstore import ( + ENTITIES, + split_data_into_entity_dfs, + build_uprating_manifest, + save_hdfstore, +) def _load_system(): @@ -43,8 +40,9 @@ def _load_system(): # --------------------------------------------------------------------------- -# h5py -> HDFStore conversion (self-contained reproduction of the builder -# logic so we don't need to import stacked_dataset_builder and its heavy deps) +# h5py reading helpers (test-specific; reads the flat h5py format +# and wraps it into the nested {var: {period: array}} structure +# expected by the production HDFStore utilities) # --------------------------------------------------------------------------- @@ -54,12 +52,9 @@ def _read_h5py_arrays(h5py_path: str): The h5py format stores ``variable / period -> array``. Periods can be yearly (``"2024"``), monthly (``"2024-01"``), or ``"ETERNITY"``. - Some h5py files are fully person-level (all arrays have the same length). - Others are already entity-level: group-entity variables have fewer rows - than person-level variables. - - Returns ``(arrays, time_period, h5_vars)`` where arrays is a dict of - ``{variable_name: numpy_array}``. + Returns ``(data, time_period, h5_vars)`` where *data* is a nested dict + ``{variable_name: {period_key: numpy_array}}`` matching the format + used by the production HDFStore utilities. """ with h5py.File(h5py_path, "r") as f: h5_vars = sorted(f.keys()) @@ -78,7 +73,7 @@ def _read_h5py_arrays(h5py_path: str): raise ValueError("Could not determine year from h5py file") time_period = int(year) - arrays = {} + data = {} for var in h5_vars: subkeys = list(f[var].keys()) @@ -94,103 +89,10 @@ def _read_h5py_arrays(h5py_path: str): arr = np.array( [x.decode() if isinstance(x, bytes) else str(x) for x in arr] ) - arrays[var] = arr - - return arrays, time_period, h5_vars - + # Wrap in nested dict keyed by the period string + data[var] = {period_key: arr} -def _split_into_entity_dfs(arrays, system, vars_to_save): - """Build entity-level DataFrames from a dict of variable arrays. - - ``arrays`` maps variable names to numpy arrays. Arrays may already be - at entity-level (different lengths for different entities) or all at - person-level. We group variables by entity, then build one DataFrame - per entity using arrays of matching length. - """ - entity_cols = {e: [] for e in ENTITIES} - - for var in sorted(vars_to_save): - if var not in arrays: - continue - if var in system.variables: - entity_key = system.variables[var].entity.key - entity_cols[entity_key].append(var) - else: - entity_cols["household"].append(var) - - # Person DataFrame: person vars + entity membership IDs - person_vars = entity_cols["person"][:] - if "person_id" not in person_vars and "person_id" in arrays: - person_vars.insert(0, "person_id") - for entity in ENTITIES[1:]: - ref_col = f"person_{entity}_id" - if ref_col in arrays: - person_vars.append(ref_col) - - person_df = pd.DataFrame({v: arrays[v] for v in person_vars if v in arrays}) - entity_dfs = {"person": person_df} - - # Group entity DataFrames - for entity in ENTITIES[1:]: - id_col = f"{entity}_id" - vars_for_entity = entity_cols[entity][:] - if id_col not in vars_for_entity and id_col in arrays: - vars_for_entity.insert(0, id_col) - - if not vars_for_entity: - continue - - # Check if the arrays are already at entity level (shorter than - # person) or at person level (same length as person_id) - n_persons = len(arrays.get("person_id", [])) - sample_len = len(arrays[vars_for_entity[0]]) - - df_data = {v: arrays[v] for v in vars_for_entity if v in arrays} - df = pd.DataFrame(df_data) - - if sample_len == n_persons and id_col in df.columns: - # Person-level: need to deduplicate by entity ID - df = df.drop_duplicates(subset=[id_col]).reset_index(drop=True) - - entity_dfs[entity] = df - - return entity_dfs - - -def _build_uprating_manifest(vars_to_save, system): - """Build manifest of variable metadata.""" - records = [] - for var in sorted(vars_to_save): - entity = ( - system.variables[var].entity.key if var in system.variables else "unknown" - ) - uprating = "" - if var in system.variables: - uprating = getattr(system.variables[var], "uprating", None) or "" - records.append({"variable": var, "entity": entity, "uprating": uprating}) - return pd.DataFrame(records) - - -def _save_hdfstore(entity_dfs, manifest_df, hdfstore_path, time_period): - """Save entity DataFrames and manifest to a Pandas HDFStore file.""" - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", - category=pd.errors.PerformanceWarning, - message=".*PyTables will pickle object types.*", - ) - with pd.HDFStore(hdfstore_path, mode="w") as store: - for entity_name, df in entity_dfs.items(): - # Deduplicate column names (can happen if a var appears - # in multiple entity buckets) - df = df.loc[:, ~df.columns.duplicated()] - for col in df.columns: - if df[col].dtype == object: - df[col] = df[col].astype(str) - store.put(entity_name, df, format="table") - store.put("_variable_metadata", manifest_df, format="table") - store.put("_time_period", pd.Series([time_period]), format="table") - return hdfstore_path + return data, time_period, h5_vars # --------------------------------------------------------------------------- @@ -201,22 +103,25 @@ def _save_hdfstore(entity_dfs, manifest_df, hdfstore_path, time_period): def h5py_to_hdfstore(h5py_path: str, hdfstore_path: str) -> dict: """Convert an h5py variable-centric file to entity-level HDFStore. + Uses the production HDFStore utilities so this test validates the + real code path rather than a local reimplementation. + Returns a summary dict with entity row counts. """ print("Loading policyengine-us system (this takes a minute)...") system = _load_system() print("Reading h5py file...") - arrays, time_period, h5_vars = _read_h5py_arrays(h5py_path) - n_persons = len(arrays.get("person_id", [])) + data, time_period, h5_vars = _read_h5py_arrays(h5py_path) + n_persons = len(next(iter(data.get("person_id", {}).values()), [])) print(f" {len(h5_vars)} variables, {n_persons:,} persons, year={time_period}") print("Splitting into entity DataFrames...") - entity_dfs = _split_into_entity_dfs(arrays, system, h5_vars) - manifest_df = _build_uprating_manifest(h5_vars, system) + entity_dfs = split_data_into_entity_dfs(data, system, time_period) + manifest_df = build_uprating_manifest(data, system) print(f"Saving HDFStore to {hdfstore_path}...") - _save_hdfstore(entity_dfs, manifest_df, hdfstore_path, time_period) + save_hdfstore(entity_dfs, manifest_df, hdfstore_path, time_period) summary = {} for entity_name, df in entity_dfs.items(): diff --git a/policyengine_us_data/utils/hdfstore.py b/policyengine_us_data/utils/hdfstore.py new file mode 100644 index 00000000..9a43a0ef --- /dev/null +++ b/policyengine_us_data/utils/hdfstore.py @@ -0,0 +1,175 @@ +""" +HDFStore serialization utilities. + +Converts variable-centric data dicts (``{var: {period: array}}``) into +entity-level Pandas HDFStore files consumed by API v2 and +``extend_single_year_dataset()``. +""" + +import warnings +from typing import Dict + +import numpy as np +import pandas as pd + +ENTITIES = [ + "person", + "household", + "tax_unit", + "spm_unit", + "family", + "marital_unit", +] + + +def split_data_into_entity_dfs( + data: Dict[str, dict], + system, + time_period: int, +) -> Dict[str, pd.DataFrame]: + """Split the data dict into per-entity DataFrames. + + Args: + data: Maps variable names to ``{period: array}`` dicts. + system: A PolicyEngine tax-benefit system. + time_period: Year to extract from each variable's period dict. + + Returns: + One DataFrame per entity, keyed by entity name. + Group entities are deduplicated by their ID column. + """ + entity_vars: Dict[str, list] = {e: [] for e in ENTITIES} + + for var_name in sorted(data.keys()): + if var_name in system.variables: + ek = system.variables[var_name].entity.key + if ek in entity_vars: + entity_vars[ek].append(var_name) + else: + entity_vars["household"].append(var_name) + + entity_dfs: Dict[str, pd.DataFrame] = {} + for entity in ENTITIES: + id_col = f"{entity}_id" + cols = {} + for var_name in entity_vars[entity]: + periods = data[var_name] + tp_key = ( + time_period if time_period in periods else str(time_period) + ) + if tp_key not in periods: + continue + arr = periods[tp_key] + if hasattr(arr, "dtype") and arr.dtype.kind == "S": + arr = np.char.decode(arr, "utf-8") + cols[var_name] = arr + + if entity == "person": + for ref_entity in ENTITIES[1:]: + ref_col = f"person_{ref_entity}_id" + if ref_col in data: + periods = data[ref_col] + tp_key = ( + time_period + if time_period in periods + else str(time_period) + ) + if tp_key in periods: + cols[ref_col] = periods[tp_key] + + if not cols: + continue + + df = pd.DataFrame(cols) + if entity != "person" and id_col in df.columns: + df = df.drop_duplicates(subset=[id_col]).reset_index(drop=True) + entity_dfs[entity] = df + + return entity_dfs + + +def build_uprating_manifest( + data: Dict[str, dict], + system, +) -> pd.DataFrame: + """Build manifest of variable metadata for embedding in HDFStore. + + Args: + data: Maps variable names to ``{period: array}`` dicts. + system: A PolicyEngine tax-benefit system. + + Returns: + DataFrame with columns: variable, entity, uprating. + """ + records = [] + for var_name in sorted(data.keys()): + entity = ( + system.variables[var_name].entity.key + if var_name in system.variables + else "unknown" + ) + uprating = "" + if var_name in system.variables: + uprating = ( + getattr(system.variables[var_name], "uprating", None) or "" + ) + records.append( + { + "variable": var_name, + "entity": entity, + "uprating": uprating, + } + ) + return pd.DataFrame(records) + + +def save_hdfstore( + entity_dfs: Dict[str, pd.DataFrame], + manifest_df: pd.DataFrame, + output_path: str, + time_period: int, +) -> str: + """Save entity DataFrames and manifest to a Pandas HDFStore file. + + Args: + entity_dfs: One DataFrame per entity from + :func:`split_data_into_entity_dfs`. + manifest_df: Variable metadata from + :func:`build_uprating_manifest`. + output_path: Path to the base ``.h5`` file. The HDFStore is + written alongside it with a ``.hdfstore.h5`` suffix. + time_period: Year stored as metadata inside the HDFStore. + + Returns: + Path to the created HDFStore file. + """ + hdfstore_path = str(output_path).replace(".h5", ".hdfstore.h5") + + print(f"\nSaving HDFStore to {hdfstore_path}...") + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=pd.errors.PerformanceWarning, + message=".*PyTables will pickle object types.*", + ) + with pd.HDFStore(hdfstore_path, mode="w") as store: + for entity_name, df in entity_dfs.items(): + for col in df.columns: + if df[col].dtype == object: + df[col] = df[col].astype(str) + store.put(entity_name, df, format="table") + + store.put("_variable_metadata", manifest_df, format="table") + store.put( + "_time_period", + pd.Series([time_period]), + format="table", + ) + + for entity_name, df in entity_dfs.items(): + print(f" {entity_name}: {len(df):,} rows, " f"{len(df.columns)} cols") + print(f" manifest: {len(manifest_df)} variables") + print("HDFStore saved successfully!") + + return hdfstore_path From 69aa1f072ef7ec437623fb4e7288de9953955b01 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 16 Mar 2026 20:50:43 +0100 Subject: [PATCH 6/8] Apply black formatting (79 char line length) Co-Authored-By: Claude Opus 4.6 --- .github/bump_version.py | 4 +- modal_app/data_build.py | 24 ++- modal_app/local_area.py | 56 ++++-- modal_app/remote_calibration_runner.py | 40 +++- modal_app/worker_script.py | 4 +- paper/scripts/build_from_content.py | 36 +++- paper/scripts/calculate_target_performance.py | 3 +- paper/scripts/generate_all_tables.py | 8 +- paper/scripts/generate_validation_metrics.py | 4 +- paper/scripts/markdown_to_latex.py | 16 +- .../calibration/block_assignment.py | 32 +++- .../calibration/calibration_utils.py | 16 +- .../calibration/clone_and_assign.py | 4 +- .../calibration/county_assignment.py | 4 +- .../calibration/create_source_imputed_cps.py | 12 +- .../calibration/create_stratified_cps.py | 17 +- .../calibration/publish_local_area.py | 82 ++++++--- .../calibration/puf_impute.py | 50 +++-- .../calibration/sanity_checks.py | 4 +- .../calibration/source_impute.py | 47 +++-- .../calibration/stacked_dataset_builder.py | 4 +- .../calibration/unified_calibration.py | 43 +++-- .../calibration/unified_matrix_builder.py | 136 ++++++++++---- .../calibration/validate_national_h5.py | 4 +- .../calibration/validate_package.py | 24 ++- .../calibration/validate_staging.py | 28 ++- policyengine_us_data/datasets/acs/acs.py | 12 +- .../datasets/acs/census_acs.py | 22 ++- .../datasets/cps/census_cps.py | 32 +++- policyengine_us_data/datasets/cps/cps.py | 171 +++++++++++++----- .../datasets/cps/enhanced_cps.py | 32 +++- .../check_calibrated_estimates_interactive.py | 66 ++++--- .../cps/long_term/extract_ssa_costs.py | 4 +- .../cps/long_term/projection_utils.py | 16 +- .../cps/long_term/run_household_projection.py | 96 ++++++---- .../datasets/cps/small_enhanced_cps.py | 15 +- policyengine_us_data/datasets/puf/irs_puf.py | 4 +- policyengine_us_data/datasets/puf/puf.py | 39 ++-- policyengine_us_data/datasets/scf/fed_scf.py | 16 +- policyengine_us_data/datasets/scf/scf.py | 36 +++- policyengine_us_data/datasets/sipp/sipp.py | 3 +- .../db/create_database_tables.py | 36 +++- .../db/create_initial_strata.py | 16 +- policyengine_us_data/db/etl_age.py | 8 +- policyengine_us_data/db/etl_irs_soi.py | 79 +++++--- policyengine_us_data/db/etl_medicaid.py | 12 +- .../db/etl_national_targets.py | 52 ++++-- policyengine_us_data/db/etl_pregnancy.py | 12 +- policyengine_us_data/db/etl_snap.py | 8 +- .../db/etl_state_income_tax.py | 10 +- policyengine_us_data/db/validate_database.py | 4 +- policyengine_us_data/db/validate_hierarchy.py | 52 ++++-- policyengine_us_data/geography/__init__.py | 4 +- policyengine_us_data/geography/county_fips.py | 8 +- .../geography/create_zip_code_dataset.py | 4 +- policyengine_us_data/parameters/__init__.py | 4 +- .../calibration_targets/audit_county_enum.py | 4 +- .../make_block_cd_distributions.py | 8 +- .../make_block_crosswalk.py | 16 +- .../make_county_cd_distributions.py | 16 +- .../make_district_mapping.py | 8 +- .../pull_hardcoded_targets.py | 8 +- .../calibration_targets/pull_snap_targets.py | 8 +- .../calibration_targets/pull_soi_targets.py | 87 ++++++--- .../storage/upload_completed_datasets.py | 8 +- .../tests/test_calibration/conftest.py | 4 +- .../test_calibration/create_test_fixture.py | 32 +++- .../test_build_matrix_masking.py | 26 ++- .../test_calibration/test_clone_and_assign.py | 13 +- .../test_county_assignment.py | 8 +- .../tests/test_calibration/test_puf_impute.py | 8 +- .../test_retirement_imputation.py | 109 +++++++---- .../test_calibration/test_source_impute.py | 4 +- .../test_stacked_dataset_builder.py | 58 +++--- .../test_calibration/test_target_config.py | 8 +- .../test_unified_calibration.py | 36 +++- .../test_unified_matrix_builder.py | 55 ++++-- .../test_calibration/test_xw_consistency.py | 11 +- .../tests/test_constraint_validation.py | 12 +- .../tests/test_database_build.py | 28 +-- .../tests/test_datasets/test_county_fips.py | 8 +- .../tests/test_datasets/test_cps.py | 17 +- .../test_datasets/test_dataset_sanity.py | 50 +++-- .../tests/test_datasets/test_enhanced_cps.py | 54 ++++-- .../tests/test_datasets/test_sipp_assets.py | 28 +-- .../test_datasets/test_small_enhanced_cps.py | 10 +- .../test_datasets/test_sparse_enhanced_cps.py | 28 ++- .../tests/test_format_comparison.py | 55 ++++-- policyengine_us_data/tests/test_puf_impute.py | 4 +- .../tests/test_schema_views_and_lookups.py | 12 +- policyengine_us_data/utils/census.py | 4 +- .../utils/constraint_validation.py | 16 +- policyengine_us_data/utils/data_upload.py | 20 +- policyengine_us_data/utils/db.py | 29 ++- policyengine_us_data/utils/huggingface.py | 12 +- policyengine_us_data/utils/loss.py | 147 ++++++++++----- policyengine_us_data/utils/randomness.py | 4 +- policyengine_us_data/utils/soi.py | 27 ++- policyengine_us_data/utils/spm.py | 4 +- policyengine_us_data/utils/uprating.py | 4 +- tests/test_h6_reform.py | 18 +- tests/test_no_formula_variables_stored.py | 26 ++- tests/test_reproducibility.py | 6 +- tests/test_weeks_unemployed.py | 6 +- validation/benefit_validation.py | 20 +- validation/generate_qrf_statistics.py | 32 +++- validation/qrf_diagnostics.py | 40 ++-- validation/tax_policy_validation.py | 8 +- validation/validate_retirement_imputation.py | 20 +- 109 files changed, 2025 insertions(+), 834 deletions(-) diff --git a/.github/bump_version.py b/.github/bump_version.py index 779a82e3..bb0fd6dd 100644 --- a/.github/bump_version.py +++ b/.github/bump_version.py @@ -19,7 +19,9 @@ def get_current_version(pyproject_path: Path) -> str: def infer_bump(changelog_dir: Path) -> str: fragments = [ - f for f in changelog_dir.iterdir() if f.is_file() and f.name != ".gitkeep" + f + for f in changelog_dir.iterdir() + if f.is_file() and f.name != ".gitkeep" ] if not fragments: print("No changelog fragments found", file=sys.stderr) diff --git a/modal_app/data_build.py b/modal_app/data_build.py index 197ee32b..2a0310c4 100644 --- a/modal_app/data_build.py +++ b/modal_app/data_build.py @@ -21,7 +21,9 @@ ) image = ( - modal.Image.debian_slim(python_version="3.13").apt_install("git").pip_install("uv") + modal.Image.debian_slim(python_version="3.13") + .apt_install("git") + .pip_install("uv") ) REPO_URL = "https://github.com/PolicyEngine/policyengine-us-data.git" @@ -90,7 +92,9 @@ def setup_gcp_credentials(): @functools.cache def get_current_commit() -> str: """Get the current git commit SHA (cached per process).""" - return subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip() + return subprocess.check_output( + ["git", "rev-parse", "HEAD"], text=True + ).strip() def get_checkpoint_path(branch: str, output_file: str) -> Path: @@ -400,7 +404,9 @@ def build_datasets( print("=== Phase 3: Building extended CPS ===") run_script_with_checkpoint( "policyengine_us_data/datasets/cps/extended_cps.py", - SCRIPT_OUTPUTS["policyengine_us_data/datasets/cps/extended_cps.py"], + SCRIPT_OUTPUTS[ + "policyengine_us_data/datasets/cps/extended_cps.py" + ], branch, checkpoint_volume, env=env, @@ -408,13 +414,17 @@ def build_datasets( # GROUP 3: After extended_cps - run in parallel # enhanced_cps and stratified_cps both depend on extended_cps - print("=== Phase 4: Building enhanced and stratified CPS (parallel) ===") + print( + "=== Phase 4: Building enhanced and stratified CPS (parallel) ===" + ) with ThreadPoolExecutor(max_workers=2) as executor: futures = [ executor.submit( run_script_with_checkpoint, "policyengine_us_data/datasets/cps/enhanced_cps.py", - SCRIPT_OUTPUTS["policyengine_us_data/datasets/cps/enhanced_cps.py"], + SCRIPT_OUTPUTS[ + "policyengine_us_data/datasets/cps/enhanced_cps.py" + ], branch, checkpoint_volume, env=env, @@ -437,7 +447,9 @@ def build_datasets( print("=== Phase 5: Building small enhanced CPS ===") run_script_with_checkpoint( "policyengine_us_data/datasets/cps/small_enhanced_cps.py", - SCRIPT_OUTPUTS["policyengine_us_data/datasets/cps/small_enhanced_cps.py"], + SCRIPT_OUTPUTS[ + "policyengine_us_data/datasets/cps/small_enhanced_cps.py" + ], branch, checkpoint_volume, env=env, diff --git a/modal_app/local_area.py b/modal_app/local_area.py index 0b0670d2..f13ae216 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -104,7 +104,9 @@ def validate_artifacts( artifacts = config.get("artifacts", {}) if not artifacts: - print("WARNING: No artifacts section in run config, skipping validation") + print( + "WARNING: No artifacts section in run config, skipping validation" + ) return for filename, expected_hash in artifacts.items(): @@ -126,7 +128,9 @@ def validate_artifacts( f" Actual: {actual}" ) - print(f"Validated {len(artifacts)} artifact(s) against run config checksums") + print( + f"Validated {len(artifacts)} artifact(s) against run config checksums" + ) def get_version() -> str: @@ -207,11 +211,15 @@ def run_phase( version_dir: Path, ) -> set: """Run a single build phase, spawning workers and collecting results.""" - work_chunks = partition_work(states, districts, cities, num_workers, completed) + work_chunks = partition_work( + states, districts, cities, num_workers, completed + ) total_remaining = sum(len(c) for c in work_chunks) print(f"\n--- Phase: {phase_name} ---") - print(f"Remaining work: {total_remaining} items across {len(work_chunks)} workers") + print( + f"Remaining work: {total_remaining} items across {len(work_chunks)} workers" + ) if total_remaining == 0: print(f"All {phase_name} items already built!") @@ -400,7 +408,9 @@ def validate_staging(branch: str, version: str) -> Dict: print(f" States: {manifest['totals']['states']}") print(f" Districts: {manifest['totals']['districts']}") print(f" Cities: {manifest['totals']['cities']}") - print(f" Total size: {manifest['totals']['total_size_bytes'] / 1e9:.2f} GB") + print( + f" Total size: {manifest['totals']['total_size_bytes'] / 1e9:.2f} GB" + ) return manifest @@ -559,9 +569,7 @@ def promote_publish(branch: str = "main", version: str = "") -> str: if result.returncode != 0: raise RuntimeError(f"Promote failed: {result.stderr}") - return ( - f"Successfully promoted version {version} with {len(manifest['files'])} files" - ) + return f"Successfully promoted version {version} with {len(manifest['files'])} files" @app.function( @@ -613,11 +621,15 @@ def coordinate_publish( "dataset": dataset_path, "database": db_path, "geography": (calibration_dir / "calibration" / "geography.npz"), - "run_config": (calibration_dir / "calibration" / "unified_run_config.json"), + "run_config": ( + calibration_dir / "calibration" / "unified_run_config.json" + ), } for label, p in required.items(): if not p.exists(): - raise RuntimeError(f"Missing required calibration input ({label}): {p}") + raise RuntimeError( + f"Missing required calibration input ({label}): {p}" + ) print("All required calibration inputs found on volume.") else: if calibration_dir.exists(): @@ -646,11 +658,15 @@ def coordinate_publish( print("Calibration inputs downloaded") dataset_path = ( - calibration_dir / "calibration" / "source_imputed_stratified_extended_cps.h5" + calibration_dir + / "calibration" + / "source_imputed_stratified_extended_cps.h5" ) geo_npz_path = calibration_dir / "calibration" / "geography.npz" - config_json_path = calibration_dir / "calibration" / "unified_run_config.json" + config_json_path = ( + calibration_dir / "calibration" / "unified_run_config.json" + ) calibration_inputs = { "weights": str(weights_path), "dataset": str(dataset_path), @@ -766,10 +782,14 @@ def coordinate_publish( ) if actual_total < expected_total: - print(f"WARNING: Expected {expected_total} files, found {actual_total}") + print( + f"WARNING: Expected {expected_total} files, found {actual_total}" + ) print("\nStarting upload to staging...") - result = upload_to_staging.remote(branch=branch, version=version, manifest=manifest) + result = upload_to_staging.remote( + branch=branch, version=version, manifest=manifest + ) print(result) print("\n" + "=" * 60) @@ -853,10 +873,14 @@ def coordinate_national_publish( staging_volume.commit() print("National calibration inputs downloaded") - weights_path = calibration_dir / "calibration" / "national_calibration_weights.npy" + weights_path = ( + calibration_dir / "calibration" / "national_calibration_weights.npy" + ) db_path = calibration_dir / "calibration" / "policy_data.db" dataset_path = ( - calibration_dir / "calibration" / "source_imputed_stratified_extended_cps.h5" + calibration_dir + / "calibration" + / "source_imputed_stratified_extended_cps.h5" ) geo_npz_path = calibration_dir / "calibration" / "national_geography.npz" diff --git a/modal_app/remote_calibration_runner.py b/modal_app/remote_calibration_runner.py index 075d5948..4853c719 100644 --- a/modal_app/remote_calibration_runner.py +++ b/modal_app/remote_calibration_runner.py @@ -5,10 +5,14 @@ app = modal.App("policyengine-us-data-fit-weights") hf_secret = modal.Secret.from_name("huggingface-token") -calibration_vol = modal.Volume.from_name("calibration-data", create_if_missing=True) +calibration_vol = modal.Volume.from_name( + "calibration-data", create_if_missing=True +) image = ( - modal.Image.debian_slim(python_version="3.11").apt_install("git").pip_install("uv") + modal.Image.debian_slim(python_version="3.11") + .apt_install("git") + .pip_install("uv") ) REPO_URL = "https://github.com/PolicyEngine/policyengine-us-data.git" @@ -48,7 +52,9 @@ def _clone_and_install(branch: str): subprocess.run(["uv", "sync", "--extra", "l0"], check=True) -def _append_hyperparams(cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq=None): +def _append_hyperparams( + cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq=None +): """Append optional hyperparameter flags to a command list.""" if beta is not None: cmd.extend(["--beta", str(beta)]) @@ -265,7 +271,9 @@ def _fit_weights_impl( cmd.append("--county-level") if workers > 1: cmd.extend(["--workers", str(workers)]) - _append_hyperparams(cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq) + _append_hyperparams( + cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq + ) cal_rc, cal_lines = _run_streaming( cmd, @@ -322,7 +330,9 @@ def _fit_from_package_impl( ] if target_config: cmd.extend(["--target-config", target_config]) - _append_hyperparams(cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq) + _append_hyperparams( + cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq + ) print(f"Running command: {' '.join(cmd)}", flush=True) @@ -337,7 +347,9 @@ def _fit_from_package_impl( return _collect_outputs(cal_lines) -def _print_provenance_from_meta(meta: dict, current_branch: str = None) -> None: +def _print_provenance_from_meta( + meta: dict, current_branch: str = None +) -> None: """Print provenance info and warn on branch mismatch.""" built = meta.get("created_at", "unknown") branch = meta.get("git_branch", "unknown") @@ -514,7 +526,9 @@ def check_volume_package() -> dict: return {"exists": False} stat = os.stat(pkg_path) - mtime = datetime.datetime.fromtimestamp(stat.st_mtime, tz=datetime.timezone.utc) + mtime = datetime.datetime.fromtimestamp( + stat.st_mtime, tz=datetime.timezone.utc + ) info = { "exists": True, "size": stat.st_size, @@ -1012,7 +1026,9 @@ def main( if vol_info.get("created_at") or vol_info.get("git_branch"): _print_provenance_from_meta(vol_info, branch) mode_label = ( - "national calibration" if national else "fitting from pre-built package" + "national calibration" + if national + else "fitting from pre-built package" ) print( "========================================", @@ -1105,8 +1121,12 @@ def main( upload_calibration_artifacts( weights_path=output, blocks_path=(blocks_output if result.get("blocks") else None), - geo_labels_path=(geo_labels_output if result.get("geo_labels") else None), - geography_path=(geography_output if result.get("geography") else None), + geo_labels_path=( + geo_labels_output if result.get("geo_labels") else None + ), + geography_path=( + geography_output if result.get("geography") else None + ), log_dir=".", prefix=prefix, ) diff --git a/modal_app/worker_script.py b/modal_app/worker_script.py index f36b59a0..3267d0fb 100644 --- a/modal_app/worker_script.py +++ b/modal_app/worker_script.py @@ -94,7 +94,9 @@ def main(): if state_fips is None: raise ValueError(f"Unknown state code: {item_id}") cd_subset = [ - cd for cd in cds_to_calibrate if int(cd) // 100 == state_fips + cd + for cd in cds_to_calibrate + if int(cd) // 100 == state_fips ] if not cd_subset: print( diff --git a/paper/scripts/build_from_content.py b/paper/scripts/build_from_content.py index 52f88389..21068f0d 100644 --- a/paper/scripts/build_from_content.py +++ b/paper/scripts/build_from_content.py @@ -47,8 +47,12 @@ def md_to_latex(self, content, section_type="section"): latex = re.sub(r"^# Abstract\n\n", "", latex) else: # Convert markdown headers to LaTeX sections - latex = re.sub(r"^# (.+)$", r"\\section{\1}", latex, flags=re.MULTILINE) - latex = re.sub(r"^## (.+)$", r"\\subsection{\1}", latex, flags=re.MULTILINE) + latex = re.sub( + r"^# (.+)$", r"\\section{\1}", latex, flags=re.MULTILINE + ) + latex = re.sub( + r"^## (.+)$", r"\\subsection{\1}", latex, flags=re.MULTILINE + ) latex = re.sub( r"^### (.+)$", r"\\subsubsection{\1}", @@ -169,11 +173,15 @@ def convert_citation(match): if len(author_list) == 1: # Handle "Author1 and Author2" format if " and " in authors: - first_author = authors.split(" and ")[0].strip().split()[-1] + first_author = ( + authors.split(" and ")[0].strip().split()[-1] + ) cite_key = f"{first_author.lower()}{year}" else: # Single author - author = author_list[0].strip().split()[-1] # Last name + author = ( + author_list[0].strip().split()[-1] + ) # Last name cite_key = f"{author.lower()}{year}" else: # Multiple authors - use first author @@ -183,7 +191,9 @@ def convert_citation(match): return f"\\citep{{{cite_key}}}" return match.group(0) # Return original if no year found - latex = re.sub(r"\(([^)]+(?:19|20)\d{2}[a-z]?)\)", convert_citation, latex) + latex = re.sub( + r"\(([^)]+(?:19|20)\d{2}[a-z]?)\)", convert_citation, latex + ) # Also handle inline citations like "Author (Year)" or "Author et al. (Year)" def convert_inline_citation(match): @@ -266,11 +276,15 @@ def convert_myst_citation(match): if len(author_list) == 1: # Handle "Author1 and Author2" format if " and " in authors: - first_author = authors.split(" and ")[0].strip().split()[-1] + first_author = ( + authors.split(" and ")[0].strip().split()[-1] + ) cite_key = f"{first_author.lower()}{year}" else: # Single author - author = author_list[0].strip().split()[-1] # Last name + author = ( + author_list[0].strip().split()[-1] + ) # Last name cite_key = f"{author.lower()}{year}" else: # Multiple authors - use first author @@ -280,7 +294,9 @@ def convert_myst_citation(match): return f"{{cite}}`{cite_key}`" return match.group(0) - myst = re.sub(r"\(([^)]+(?:19|20)\d{2}[a-z]?)\)", convert_myst_citation, myst) + myst = re.sub( + r"\(([^)]+(?:19|20)\d{2}[a-z]?)\)", convert_myst_citation, myst + ) # Handle inline citations like "Author (Year)" - convert to {cite:t}`author_year` def convert_inline_myst(match): @@ -327,7 +343,9 @@ def process_content_file(self, content_file): # LaTeX conversion if stem == "abstract": latex_content = self.md_to_latex(content, section_type="abstract") - latex_content = f"\\begin{{abstract}}\n{latex_content}\n\\end{{abstract}}" + latex_content = ( + f"\\begin{{abstract}}\n{latex_content}\n\\end{{abstract}}" + ) latex_path = self.paper_dir / "abstract.tex" elif stem == "introduction": latex_content = self.md_to_latex(content) diff --git a/paper/scripts/calculate_target_performance.py b/paper/scripts/calculate_target_performance.py index 8f5a65f1..1a50ab3c 100644 --- a/paper/scripts/calculate_target_performance.py +++ b/paper/scripts/calculate_target_performance.py @@ -79,7 +79,8 @@ def compare_dataset_performance( # Calculate average improvement by target category categories = { - "IRS Income": lambda x: "employment_income" in x or "capital_gains" in x, + "IRS Income": lambda x: "employment_income" in x + or "capital_gains" in x, "Demographics": lambda x: "age_" in x or "population" in x, "Programs": lambda x: "snap" in x or "social_security" in x, "Tax Expenditures": lambda x: "salt" in x or "charitable" in x, diff --git a/paper/scripts/generate_all_tables.py b/paper/scripts/generate_all_tables.py index 690b528d..8f476203 100644 --- a/paper/scripts/generate_all_tables.py +++ b/paper/scripts/generate_all_tables.py @@ -33,7 +33,9 @@ def create_latex_table(df, caption, label, float_format=None): # Format the dataframe as LaTeX if float_format: - table_body = df.to_latex(index=False, escape=False, float_format=float_format) + table_body = df.to_latex( + index=False, escape=False, float_format=float_format + ) else: table_body = df.to_latex(index=False, escape=False) @@ -42,7 +44,9 @@ def create_latex_table(df, caption, label, float_format=None): tabular_start = next( i for i, line in enumerate(lines) if "\\begin{tabular}" in line ) - tabular_end = next(i for i, line in enumerate(lines) if "\\end{tabular}" in line) + tabular_end = next( + i for i, line in enumerate(lines) if "\\end{tabular}" in line + ) # Indent the tabular content for i in range(tabular_start, tabular_end + 1): diff --git a/paper/scripts/generate_validation_metrics.py b/paper/scripts/generate_validation_metrics.py index 90b3624d..db586959 100644 --- a/paper/scripts/generate_validation_metrics.py +++ b/paper/scripts/generate_validation_metrics.py @@ -235,7 +235,9 @@ def main(): print(f"\nResults saved to {results_dir}/") print("\nNOTE: All metrics marked as [TO BE CALCULATED] require full") - print("dataset generation and microsimulation runs to compute actual values.") + print( + "dataset generation and microsimulation runs to compute actual values." + ) if __name__ == "__main__": diff --git a/paper/scripts/markdown_to_latex.py b/paper/scripts/markdown_to_latex.py index 7cc80b04..5c3b0e3b 100644 --- a/paper/scripts/markdown_to_latex.py +++ b/paper/scripts/markdown_to_latex.py @@ -24,8 +24,12 @@ def convert_markdown_to_latex(markdown_content: str) -> str: # Convert headers latex = re.sub(r"^# (.+)$", r"\\section{\1}", latex, flags=re.MULTILINE) - latex = re.sub(r"^## (.+)$", r"\\subsection{\1}", latex, flags=re.MULTILINE) - latex = re.sub(r"^### (.+)$", r"\\subsubsection{\1}", latex, flags=re.MULTILINE) + latex = re.sub( + r"^## (.+)$", r"\\subsection{\1}", latex, flags=re.MULTILINE + ) + latex = re.sub( + r"^### (.+)$", r"\\subsubsection{\1}", latex, flags=re.MULTILINE + ) # Convert bold and italic latex = re.sub(r"\*\*(.+?)\*\*", r"\\textbf{\1}", latex) @@ -63,7 +67,9 @@ def convert_markdown_to_latex(markdown_content: str) -> str: # Manage list stack while len(list_stack) > indent_level + 1: - new_lines.append(" " * (len(list_stack) - 1) + "\\end{itemize}") + new_lines.append( + " " * (len(list_stack) - 1) + "\\end{itemize}" + ) list_stack.pop() if len(list_stack) <= indent_level: @@ -75,7 +81,9 @@ def convert_markdown_to_latex(markdown_content: str) -> str: else: # Close any open lists while list_stack: - new_lines.append(" " * (len(list_stack) - 1) + "\\end{itemize}") + new_lines.append( + " " * (len(list_stack) - 1) + "\\end{itemize}" + ) list_stack.pop() new_lines.append(line) in_list = False diff --git a/policyengine_us_data/calibration/block_assignment.py b/policyengine_us_data/calibration/block_assignment.py index 83af388f..3ce09289 100644 --- a/policyengine_us_data/calibration/block_assignment.py +++ b/policyengine_us_data/calibration/block_assignment.py @@ -138,9 +138,7 @@ def _load_cbsa_crosswalk() -> Dict[str, str]: Returns: Dict mapping 5-digit county FIPS to CBSA code (or None if not in CBSA) """ - url = ( - "https://data.nber.org/cbsa-csa-fips-county-crosswalk/2023/cbsa2fipsxw_2023.csv" - ) + url = "https://data.nber.org/cbsa-csa-fips-county-crosswalk/2023/cbsa2fipsxw_2023.csv" try: df = pd.read_csv(url, dtype=str) # Build 5-digit county FIPS from state + county codes @@ -272,10 +270,14 @@ def get_all_geography_from_block(block_geoid: str) -> Dict[str, Optional[str]]: result = { "sldu": row["sldu"] if pd.notna(row["sldu"]) else None, "sldl": row["sldl"] if pd.notna(row["sldl"]) else None, - "place_fips": (row["place_fips"] if pd.notna(row["place_fips"]) else None), + "place_fips": ( + row["place_fips"] if pd.notna(row["place_fips"]) else None + ), "vtd": row["vtd"] if pd.notna(row["vtd"]) else None, "puma": row["puma"] if pd.notna(row["puma"]) else None, - "zcta": (row["zcta"] if has_zcta and pd.notna(row["zcta"]) else None), + "zcta": ( + row["zcta"] if has_zcta and pd.notna(row["zcta"]) else None + ), } return result return { @@ -444,11 +446,17 @@ def assign_geography_for_cd( - county_index: int32 indices into County enum (for backwards compat) """ # Assign blocks first - block_geoids = assign_blocks_for_cd(cd_geoid, n_households, seed, distributions) + block_geoids = assign_blocks_for_cd( + cd_geoid, n_households, seed, distributions + ) # Derive geography directly from block GEOID structure - county_fips = np.array([get_county_fips_from_block(b) for b in block_geoids]) - tract_geoids = np.array([get_tract_geoid_from_block(b) for b in block_geoids]) + county_fips = np.array( + [get_county_fips_from_block(b) for b in block_geoids] + ) + tract_geoids = np.array( + [get_tract_geoid_from_block(b) for b in block_geoids] + ) state_fips = np.array([get_state_fips_from_block(b) for b in block_geoids]) # CBSA lookup via county (may be None for rural areas) @@ -525,8 +533,12 @@ def derive_geography_from_blocks( Returns: Dict with same keys as assign_geography_for_cd. """ - county_fips = np.array([get_county_fips_from_block(b) for b in block_geoids]) - tract_geoids = np.array([get_tract_geoid_from_block(b) for b in block_geoids]) + county_fips = np.array( + [get_county_fips_from_block(b) for b in block_geoids] + ) + tract_geoids = np.array( + [get_tract_geoid_from_block(b) for b in block_geoids] + ) state_fips = np.array([get_state_fips_from_block(b) for b in block_geoids]) cbsa_codes = np.array([get_cbsa_from_county(c) or "" for c in county_fips]) county_indices = np.array( diff --git a/policyengine_us_data/calibration/calibration_utils.py b/policyengine_us_data/calibration/calibration_utils.py index 9d10ee6a..5cf9f1bc 100644 --- a/policyengine_us_data/calibration/calibration_utils.py +++ b/policyengine_us_data/calibration/calibration_utils.py @@ -352,7 +352,9 @@ def create_target_groups( for domain_var, var_name in pairs: var_mask = ( - (targets_df["variable"] == var_name) & level_mask & ~processed_mask + (targets_df["variable"] == var_name) + & level_mask + & ~processed_mask ) if has_domain and domain_var is not None: var_mask &= targets_df["domain_variable"] == domain_var @@ -378,11 +380,15 @@ def create_target_groups( # Format output based on level and count if n_targets == 1: value = matching["value"].iloc[0] - info_str = f"{level_name} {label} (1 target, value={value:,.0f})" + info_str = ( + f"{level_name} {label} (1 target, value={value:,.0f})" + ) print_str = f" Group {group_id}: {label} = {value:,.0f}" else: info_str = f"{level_name} {label} ({n_targets} targets)" - print_str = f" Group {group_id}: {label} ({n_targets} targets)" + print_str = ( + f" Group {group_id}: {label} ({n_targets} targets)" + ) group_info.append(f"Group {group_id}: {info_str}") print(print_str) @@ -622,7 +628,9 @@ def calculate_spm_thresholds_vectorized( for i in range(n_units): tenure_str = TENURE_CODE_MAP.get(int(tenure_codes[i]), "renter") base = base_thresholds[tenure_str] - equiv_scale = spm_equivalence_scale(int(num_adults[i]), int(num_children[i])) + equiv_scale = spm_equivalence_scale( + int(num_adults[i]), int(num_children[i]) + ) thresholds[i] = base * equiv_scale * spm_unit_geoadj[i] return thresholds diff --git a/policyengine_us_data/calibration/clone_and_assign.py b/policyengine_us_data/calibration/clone_and_assign.py index a140f1b1..bc85dfd8 100644 --- a/policyengine_us_data/calibration/clone_and_assign.py +++ b/policyengine_us_data/calibration/clone_and_assign.py @@ -110,7 +110,9 @@ def assign_random_geography( n_bad = collisions.sum() if n_bad == 0: break - clone_indices[collisions] = rng.choice(len(blocks), size=n_bad, p=probs) + clone_indices[collisions] = rng.choice( + len(blocks), size=n_bad, p=probs + ) clone_cds = cds[clone_indices] collisions = np.zeros(n_records, dtype=bool) for prev in range(clone_idx): diff --git a/policyengine_us_data/calibration/county_assignment.py b/policyengine_us_data/calibration/county_assignment.py index a1f262d7..6d32d30b 100644 --- a/policyengine_us_data/calibration/county_assignment.py +++ b/policyengine_us_data/calibration/county_assignment.py @@ -150,7 +150,9 @@ def get_county_filter_probability( else: dist = _generate_uniform_distribution(cd_key) - return sum(prob for county, prob in dist.items() if county in county_filter) + return sum( + prob for county, prob in dist.items() if county in county_filter + ) def get_filtered_county_distribution( diff --git a/policyengine_us_data/calibration/create_source_imputed_cps.py b/policyengine_us_data/calibration/create_source_imputed_cps.py index 68dd876a..4381f72d 100644 --- a/policyengine_us_data/calibration/create_source_imputed_cps.py +++ b/policyengine_us_data/calibration/create_source_imputed_cps.py @@ -19,7 +19,9 @@ logger = logging.getLogger(__name__) INPUT_PATH = str(STORAGE_FOLDER / "stratified_extended_cps_2024.h5") -OUTPUT_PATH = str(STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5") +OUTPUT_PATH = str( + STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5" +) def create_source_imputed_cps( @@ -47,7 +49,9 @@ def create_source_imputed_cps( logger.info("Loaded %d households, time_period=%d", n_records, time_period) - geography = assign_random_geography(n_records=n_records, n_clones=1, seed=seed) + geography = assign_random_geography( + n_records=n_records, n_clones=1, seed=seed + ) base_states = geography.state_fips[:n_records] raw_data = sim.dataset.load_dataset() @@ -55,7 +59,9 @@ def create_source_imputed_cps( for var in raw_data: val = raw_data[var] if isinstance(val, dict): - data_dict[var] = {int(k) if k.isdigit() else k: v for k, v in val.items()} + data_dict[var] = { + int(k) if k.isdigit() else k: v for k, v in val.items() + } else: data_dict[var] = {time_period: val[...]} diff --git a/policyengine_us_data/calibration/create_stratified_cps.py b/policyengine_us_data/calibration/create_stratified_cps.py index 2aa15a9f..e2632366 100644 --- a/policyengine_us_data/calibration/create_stratified_cps.py +++ b/policyengine_us_data/calibration/create_stratified_cps.py @@ -79,7 +79,9 @@ def create_stratified_cps_dataset( f" Top {100 - high_income_percentile}% (AGI >= ${high_income_threshold:,.0f}): {n_top:,}" ) print(f" Middle 25-{high_income_percentile}%: {n_middle:,}") - print(f" Bottom 25% (AGI < ${bottom_25_pct_threshold:,.0f}): {n_bottom_25:,}") + print( + f" Bottom 25% (AGI < ${bottom_25_pct_threshold:,.0f}): {n_bottom_25:,}" + ) # Calculate sampling rates # Keep ALL top earners, distribute remaining quota between middle and bottom @@ -130,7 +132,9 @@ def create_stratified_cps_dataset( # Top earners - keep all top_mask = agi >= high_income_threshold selected_mask[top_mask] = True - print(f" Top {100 - high_income_percentile}%: selected {np.sum(top_mask):,}") + print( + f" Top {100 - high_income_percentile}%: selected {np.sum(top_mask):,}" + ) # Bottom 25% bottom_mask = agi < bottom_25_pct_threshold @@ -267,7 +271,10 @@ def create_stratified_cps_dataset( if "person_id" in f and str(time_period) in f["person_id"]: person_ids = f["person_id"][str(time_period)][:] print(f" Final persons: {len(person_ids):,}") - if "household_weight" in f and str(time_period) in f["household_weight"]: + if ( + "household_weight" in f + and str(time_period) in f["household_weight"] + ): weights = f["household_weight"][str(time_period)][:] print(f" Final household weights sum: {np.sum(weights):,.0f}") @@ -335,5 +342,7 @@ def create_stratified_cps_dataset( ) print("\nExamples:") print(" python create_stratified_cps.py 30000") - print(" python create_stratified_cps.py 50000 --top=99.5 --oversample-poor") + print( + " python create_stratified_cps.py 50000 --top=99.5 --oversample-poor" + ) print(" python create_stratified_cps.py 30000 --seed=123 # reproducible") diff --git a/policyengine_us_data/calibration/publish_local_area.py b/policyengine_us_data/calibration/publish_local_area.py index ad3ccb1a..8e505351 100644 --- a/policyengine_us_data/calibration/publish_local_area.py +++ b/policyengine_us_data/calibration/publish_local_area.py @@ -168,14 +168,17 @@ def build_h5( # CD subset filtering: zero out cells whose CD isn't in subset if cd_subset is not None: cd_subset_set = set(cd_subset) - cd_mask = np.vectorize(lambda cd: cd in cd_subset_set)(clone_cds_matrix) + cd_mask = np.vectorize(lambda cd: cd in cd_subset_set)( + clone_cds_matrix + ) W[~cd_mask] = 0 # County filtering: scale weights by P(target_counties | CD) if county_filter is not None: unique_cds = np.unique(clone_cds_matrix) cd_prob = { - cd: get_county_filter_probability(cd, county_filter) for cd in unique_cds + cd: get_county_filter_probability(cd, county_filter) + for cd in unique_cds } p_matrix = np.vectorize( cd_prob.__getitem__, @@ -202,11 +205,15 @@ def build_h5( ) clone_weights = W[active_geo, active_hh] active_blocks = blocks.reshape(n_clones_total, n_hh)[active_geo, active_hh] - active_clone_cds = clone_cds.reshape(n_clones_total, n_hh)[active_geo, active_hh] + active_clone_cds = clone_cds.reshape(n_clones_total, n_hh)[ + active_geo, active_hh + ] empty_count = np.sum(active_blocks == "") if empty_count > 0: - raise ValueError(f"{empty_count} active clones have empty block GEOIDs") + raise ValueError( + f"{empty_count} active clones have empty block GEOIDs" + ) print(f"Active clones: {n_clones:,}") print(f"Total weight: {clone_weights.sum():,.0f}") @@ -251,12 +258,16 @@ def build_h5( # === Build clone index arrays === hh_clone_idx = active_hh - persons_per_clone = np.array([len(hh_to_persons.get(h, [])) for h in active_hh]) + persons_per_clone = np.array( + [len(hh_to_persons.get(h, [])) for h in active_hh] + ) person_parts = [ np.array(hh_to_persons.get(h, []), dtype=np.int64) for h in active_hh ] person_clone_idx = ( - np.concatenate(person_parts) if person_parts else np.array([], dtype=np.int64) + np.concatenate(person_parts) + if person_parts + else np.array([], dtype=np.int64) ) entity_clone_idx = {} @@ -265,7 +276,8 @@ def build_h5( epc = np.array([len(hh_to_entity[ek].get(h, [])) for h in active_hh]) entities_per_clone[ek] = epc parts = [ - np.array(hh_to_entity[ek].get(h, []), dtype=np.int64) for h in active_hh + np.array(hh_to_entity[ek].get(h, []), dtype=np.int64) + for h in active_hh ] entity_clone_idx[ek] = ( np.concatenate(parts) if parts else np.array([], dtype=np.int64) @@ -304,7 +316,9 @@ def build_h5( sorted_keys = entity_keys[sorted_order] sorted_new = new_entity_ids[ek][sorted_order] - p_old_eids = person_entity_id_arrays[ek][person_clone_idx].astype(np.int64) + p_old_eids = person_entity_id_arrays[ek][person_clone_idx].astype( + np.int64 + ) person_keys = clone_ids_for_persons * offset + p_old_eids positions = np.searchsorted(sorted_keys, person_keys) @@ -459,7 +473,9 @@ def build_h5( } # === Gap 4: Congressional district GEOID === - clone_cd_geoids = np.array([int(cd) for cd in active_clone_cds], dtype=np.int32) + clone_cd_geoids = np.array( + [int(cd) for cd in active_clone_cds], dtype=np.int32 + ) data["congressional_district_geoid"] = { time_period: clone_cd_geoids, } @@ -479,7 +495,9 @@ def build_h5( ) # Get cloned person ages and SPM unit IDs - person_ages = sim.calculate("age", map_to="person").values[person_clone_idx] + person_ages = sim.calculate("age", map_to="person").values[ + person_clone_idx + ] # Get cloned tenure types spm_tenure_holder = sim.get_holder("spm_unit_tenure_type") @@ -571,7 +589,9 @@ def build_h5( print(f"Total population (person weights): {pw.sum():,.0f}") # === HDFStore output (entity-level format) === - entity_dfs = split_data_into_entity_dfs(data, sim.tax_benefit_system, time_period) + entity_dfs = split_data_into_entity_dfs( + data, sim.tax_benefit_system, time_period + ) manifest_df = build_uprating_manifest(data, sim.tax_benefit_system) save_hdfstore(entity_dfs, manifest_df, str(output_path), time_period) @@ -639,7 +659,9 @@ def build_states( if upload: print(f"Uploading {state_code}.h5 to GCP...") - upload_local_area_file(str(output_path), "states", skip_hf=True) + upload_local_area_file( + str(output_path), "states", skip_hf=True + ) # Upload HDFStore file if it exists hdfstore_path = str(output_path).replace(".h5", ".hdfstore.h5") @@ -657,7 +679,9 @@ def build_states( print(f"Completed {state_code}") if upload and len(hf_queue) >= hf_batch_size: - print(f"\nUploading batch of {len(hf_queue)} files to HuggingFace...") + print( + f"\nUploading batch of {len(hf_queue)} files to HuggingFace..." + ) upload_local_area_batch_to_hf(hf_queue) hf_queue = [] @@ -666,7 +690,9 @@ def build_states( raise if upload and hf_queue: - print(f"\nUploading final batch of {len(hf_queue)} files to HuggingFace...") + print( + f"\nUploading final batch of {len(hf_queue)} files to HuggingFace..." + ) upload_local_area_batch_to_hf(hf_queue) @@ -718,7 +744,9 @@ def build_districts( if upload: print(f"Uploading {friendly_name}.h5 to GCP...") - upload_local_area_file(str(output_path), "districts", skip_hf=True) + upload_local_area_file( + str(output_path), "districts", skip_hf=True + ) # Upload HDFStore file if it exists hdfstore_path = str(output_path).replace(".h5", ".hdfstore.h5") @@ -736,7 +764,9 @@ def build_districts( print(f"Completed {friendly_name}") if upload and len(hf_queue) >= hf_batch_size: - print(f"\nUploading batch of {len(hf_queue)} files to HuggingFace...") + print( + f"\nUploading batch of {len(hf_queue)} files to HuggingFace..." + ) upload_local_area_batch_to_hf(hf_queue) hf_queue = [] @@ -745,7 +775,9 @@ def build_districts( raise if upload and hf_queue: - print(f"\nUploading final batch of {len(hf_queue)} files to HuggingFace...") + print( + f"\nUploading final batch of {len(hf_queue)} files to HuggingFace..." + ) upload_local_area_batch_to_hf(hf_queue) @@ -792,10 +824,14 @@ def build_cities( if upload: print("Uploading NYC.h5 to GCP...") - upload_local_area_file(str(output_path), "cities", skip_hf=True) + upload_local_area_file( + str(output_path), "cities", skip_hf=True + ) # Upload HDFStore file if it exists - hdfstore_path = str(output_path).replace(".h5", ".hdfstore.h5") + hdfstore_path = str(output_path).replace( + ".h5", ".hdfstore.h5" + ) if os.path.exists(hdfstore_path): print("Uploading NYC.hdfstore.h5 to GCP...") upload_local_area_file( @@ -814,7 +850,9 @@ def build_cities( raise if upload and hf_queue: - print(f"\nUploading batch of {len(hf_queue)} city files to HuggingFace...") + print( + f"\nUploading batch of {len(hf_queue)} city files to HuggingFace..." + ) upload_local_area_batch_to_hf(hf_queue) @@ -891,7 +929,9 @@ def main(): elif args.skip_download: inputs = { "weights": WORK_DIR / "calibration_weights.npy", - "dataset": (WORK_DIR / "source_imputed_stratified_extended_cps.h5"), + "dataset": ( + WORK_DIR / "source_imputed_stratified_extended_cps.h5" + ), } print("Using existing files in work directory:") for key, path in inputs.items(): diff --git a/policyengine_us_data/calibration/puf_impute.py b/policyengine_us_data/calibration/puf_impute.py index 445bd758..dfdada5f 100644 --- a/policyengine_us_data/calibration/puf_impute.py +++ b/policyengine_us_data/calibration/puf_impute.py @@ -194,7 +194,9 @@ "social_security", ] -RETIREMENT_PREDICTORS = RETIREMENT_DEMOGRAPHIC_PREDICTORS + RETIREMENT_INCOME_PREDICTORS +RETIREMENT_PREDICTORS = ( + RETIREMENT_DEMOGRAPHIC_PREDICTORS + RETIREMENT_INCOME_PREDICTORS +) def _get_retirement_limits(year: int) -> dict: @@ -409,7 +411,9 @@ def reconcile_ss_subcomponents( if puf_has_ss.any(): shares = _qrf_ss_shares(data, n_cps, time_period, puf_has_ss) if shares is None: - shares = _age_heuristic_ss_shares(data, n_cps, time_period, puf_has_ss) + shares = _age_heuristic_ss_shares( + data, n_cps, time_period, puf_has_ss + ) for sub in SS_SUBCOMPONENTS: if sub not in data: @@ -488,13 +492,17 @@ def _map_to_entity(pred_values, variable_name): return pred_values entity = var_meta.entity.key if entity != "person": - return cps_sim.populations[entity].value_from_first_person(pred_values) + return cps_sim.populations[entity].value_from_first_person( + pred_values + ) return pred_values # Impute weeks_unemployed for PUF half puf_weeks = None if y_full is not None and dataset_path is not None: - puf_weeks = _impute_weeks_unemployed(data, y_full, time_period, dataset_path) + puf_weeks = _impute_weeks_unemployed( + data, y_full, time_period, dataset_path + ) # Impute retirement contributions for PUF half puf_retirement = None @@ -518,14 +526,24 @@ def _map_to_entity(pred_values, variable_name): time_period: np.concatenate([values, values + values.max()]) } elif "_weight" in variable: - new_data[variable] = {time_period: np.concatenate([values, values * 0])} + new_data[variable] = { + time_period: np.concatenate([values, values * 0]) + } elif variable == "weeks_unemployed" and puf_weeks is not None: - new_data[variable] = {time_period: np.concatenate([values, puf_weeks])} - elif variable in CPS_RETIREMENT_VARIABLES and puf_retirement is not None: + new_data[variable] = { + time_period: np.concatenate([values, puf_weeks]) + } + elif ( + variable in CPS_RETIREMENT_VARIABLES and puf_retirement is not None + ): puf_vals = puf_retirement[variable] - new_data[variable] = {time_period: np.concatenate([values, puf_vals])} + new_data[variable] = { + time_period: np.concatenate([values, puf_vals]) + } else: - new_data[variable] = {time_period: np.concatenate([values, values])} + new_data[variable] = { + time_period: np.concatenate([values, values]) + } new_data["state_fips"] = { time_period: np.concatenate([state_fips, state_fips]).astype(np.int32) @@ -638,7 +656,11 @@ def _impute_weeks_unemployed( logger.info( "Imputed weeks_unemployed for PUF: %d with weeks > 0, mean = %.1f", (imputed_weeks > 0).sum(), - (imputed_weeks[imputed_weeks > 0].mean() if (imputed_weeks > 0).any() else 0), + ( + imputed_weeks[imputed_weeks > 0].mean() + if (imputed_weeks > 0).any() + else 0 + ), ) return imputed_weeks @@ -800,7 +822,9 @@ def _run_qrf_imputation( puf_sim = Microsimulation(dataset=puf_dataset) - puf_agi = puf_sim.calculate("adjusted_gross_income", map_to="person").values + puf_agi = puf_sim.calculate( + "adjusted_gross_income", map_to="person" + ).values X_train_full = puf_sim.calculate_dataframe( DEMOGRAPHIC_PREDICTORS + IMPUTED_VARIABLES @@ -877,7 +901,9 @@ def _stratified_subsample_index( if remaining_quota >= len(bottom_idx): selected_bottom = bottom_idx else: - selected_bottom = rng.choice(bottom_idx, size=remaining_quota, replace=False) + selected_bottom = rng.choice( + bottom_idx, size=remaining_quota, replace=False + ) selected = np.concatenate([top_idx, selected_bottom]) selected.sort() diff --git a/policyengine_us_data/calibration/sanity_checks.py b/policyengine_us_data/calibration/sanity_checks.py index e1f59064..0ea59218 100644 --- a/policyengine_us_data/calibration/sanity_checks.py +++ b/policyengine_us_data/calibration/sanity_checks.py @@ -214,7 +214,9 @@ def _get(f, path): { "check": "per_hh_employment_income", "status": "WARN", - "detail": (f"${per_hh:,.0f}/hh (expected $10K-$200K)"), + "detail": ( + f"${per_hh:,.0f}/hh (expected $10K-$200K)" + ), } ) else: diff --git a/policyengine_us_data/calibration/source_impute.py b/policyengine_us_data/calibration/source_impute.py index 25c7975a..339e038e 100644 --- a/policyengine_us_data/calibration/source_impute.py +++ b/policyengine_us_data/calibration/source_impute.py @@ -225,7 +225,9 @@ def _person_state_fips( if hh_ids_person is not None: hh_ids = data["household_id"][time_period] hh_to_idx = {int(hh_id): i for i, hh_id in enumerate(hh_ids)} - return np.array([state_fips[hh_to_idx[int(hh_id)]] for hh_id in hh_ids_person]) + return np.array( + [state_fips[hh_to_idx[int(hh_id)]] for hh_id in hh_ids_person] + ) # Fallback: distribute persons across households as evenly # as possible (first households get any remainder). n_hh = len(data["household_id"][time_period]) @@ -262,9 +264,9 @@ def _impute_acs( predictors = ACS_PREDICTORS + ["state_fips"] acs_df = acs.calculate_dataframe(ACS_PREDICTORS + ACS_IMPUTED_VARIABLES) - acs_df["state_fips"] = acs.calculate("state_fips", map_to="person").values.astype( - np.float32 - ) + acs_df["state_fips"] = acs.calculate( + "state_fips", map_to="person" + ).values.astype(np.float32) train_df = acs_df[acs_df.is_household_head].sample(10_000, random_state=42) train_df = _encode_tenure_type(train_df) @@ -366,10 +368,16 @@ def _impute_sipp( sipp_df["is_under_18"] = sipp_df.TAGE < 18 sipp_df["is_under_6"] = sipp_df.TAGE < 6 sipp_df["count_under_18"] = ( - sipp_df.groupby("SSUID")["is_under_18"].sum().loc[sipp_df.SSUID.values].values + sipp_df.groupby("SSUID")["is_under_18"] + .sum() + .loc[sipp_df.SSUID.values] + .values ) sipp_df["count_under_6"] = ( - sipp_df.groupby("SSUID")["is_under_6"].sum().loc[sipp_df.SSUID.values].values + sipp_df.groupby("SSUID")["is_under_6"] + .sum() + .loc[sipp_df.SSUID.values] + .values ) tip_cols = [ @@ -400,9 +408,9 @@ def _impute_sipp( age_df = pd.DataFrame({"hh": hh_ids_person, "age": person_ages}) under_18 = age_df.groupby("hh")["age"].apply(lambda x: (x < 18).sum()) under_6 = age_df.groupby("hh")["age"].apply(lambda x: (x < 6).sum()) - cps_tip_df["count_under_18"] = under_18.loc[hh_ids_person].values.astype( - np.float32 - ) + cps_tip_df["count_under_18"] = under_18.loc[ + hh_ids_person + ].values.astype(np.float32) cps_tip_df["count_under_6"] = under_6.loc[hh_ids_person].values.astype( np.float32 ) @@ -491,7 +499,10 @@ def _impute_sipp( asset_train.index, size=min(20_000, len(asset_train)), replace=True, - p=(asset_train.household_weight / asset_train.household_weight.sum()), + p=( + asset_train.household_weight + / asset_train.household_weight.sum() + ), ) ] @@ -502,15 +513,15 @@ def _impute_sipp( ["employment_income", "age", "is_male"], ) if "is_male" in cps_asset_df.columns: - cps_asset_df["is_female"] = (~cps_asset_df["is_male"].astype(bool)).astype( - np.float32 - ) + cps_asset_df["is_female"] = ( + ~cps_asset_df["is_male"].astype(bool) + ).astype(np.float32) else: cps_asset_df["is_female"] = 0.0 if "is_married" in data: - cps_asset_df["is_married"] = data["is_married"][time_period].astype( - np.float32 - ) + cps_asset_df["is_married"] = data["is_married"][ + time_period + ].astype(np.float32) else: cps_asset_df["is_married"] = 0.0 cps_asset_df["count_under_18"] = ( @@ -612,7 +623,9 @@ def _impute_scf( cps_df = _build_cps_receiver(data, time_period, dataset_path, pe_vars) if "is_male" in cps_df.columns: - cps_df["is_female"] = (~cps_df["is_male"].astype(bool)).astype(np.float32) + cps_df["is_female"] = (~cps_df["is_male"].astype(bool)).astype( + np.float32 + ) else: cps_df["is_female"] = 0.0 diff --git a/policyengine_us_data/calibration/stacked_dataset_builder.py b/policyengine_us_data/calibration/stacked_dataset_builder.py index 0089f0d1..172f05fb 100644 --- a/policyengine_us_data/calibration/stacked_dataset_builder.py +++ b/policyengine_us_data/calibration/stacked_dataset_builder.py @@ -105,7 +105,9 @@ f"{geography.n_records} records" ) - print(f"Geography: {geography.n_clones} clones x {geography.n_records} records") + print( + f"Geography: {geography.n_clones} clones x {geography.n_records} records" + ) takeup_filter = [spec["variable"] for spec in SIMPLE_TAKEUP_VARS] diff --git a/policyengine_us_data/calibration/unified_calibration.py b/policyengine_us_data/calibration/unified_calibration.py index 66bc1f9b..361f0dba 100644 --- a/policyengine_us_data/calibration/unified_calibration.py +++ b/policyengine_us_data/calibration/unified_calibration.py @@ -136,7 +136,9 @@ def check_package_staleness(metadata: dict) -> None: built_dt = datetime.datetime.fromisoformat(created) age = datetime.datetime.now() - built_dt if age.days > 7: - print(f"WARNING: Package is {age.days} days old (built {created})") + print( + f"WARNING: Package is {age.days} days old (built {created})" + ) except Exception: pass @@ -169,7 +171,9 @@ def check_package_staleness(metadata: dict) -> None: def parse_args(argv=None): - parser = argparse.ArgumentParser(description="Unified L0 calibration pipeline") + parser = argparse.ArgumentParser( + description="Unified L0 calibration pipeline" + ) parser.add_argument( "--dataset", default=None, @@ -338,7 +342,9 @@ def _match_rules(targets_df, rules): for rule in rules: rule_mask = targets_df["variable"] == rule["variable"] if "geo_level" in rule: - rule_mask = rule_mask & (targets_df["geo_level"] == rule["geo_level"]) + rule_mask = rule_mask & ( + targets_df["geo_level"] == rule["geo_level"] + ) if "domain_variable" in rule: rule_mask = rule_mask & ( targets_df["domain_variable"] == rule["domain_variable"] @@ -578,7 +584,9 @@ def fit_l0_weights( import torch - os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") + os.environ.setdefault( + "PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True" + ) n_total = X_sparse.shape[1] if initial_weights is None: @@ -621,7 +629,9 @@ def _flushed_print(*args, **kwargs): builtins.print = _flushed_print enable_logging = ( - log_freq is not None and log_path is not None and target_names is not None + log_freq is not None + and log_path is not None + and target_names is not None ) if enable_logging: Path(log_path).parent.mkdir(parents=True, exist_ok=True) @@ -658,7 +668,9 @@ def _flushed_print(*args, **kwargs): with torch.no_grad(): y_pred = model.predict(X_sparse).cpu().numpy() - weights_snap = model.get_weights(deterministic=True).cpu().numpy() + weights_snap = ( + model.get_weights(deterministic=True).cpu().numpy() + ) active_w = weights_snap[weights_snap > 0] nz = len(active_w) @@ -702,7 +714,9 @@ def _flushed_print(*args, **kwargs): flush=True, ) - ach_flags = achievable if achievable is not None else [True] * len(targets) + ach_flags = ( + achievable if achievable is not None else [True] * len(targets) + ) with open(log_path, "a") as f: for i in range(len(targets)): est = y_pred[i] @@ -973,7 +987,8 @@ def run_calibration( ) source_path = str( - Path(dataset_path).parent / f"source_imputed_{Path(dataset_path).stem}.h5" + Path(dataset_path).parent + / f"source_imputed_{Path(dataset_path).stem}.h5" ) with h5py.File(source_path, "w") as f: for var, time_dict in data_dict.items(): @@ -1174,7 +1189,9 @@ def main(argv=None): f"Dataset not found: {dataset_path}\n" "Run 'make data' first, or pass --dataset with a valid path." ) - db_path = args.db_path or str(STORAGE_FOLDER / "calibration" / "policy_data.db") + db_path = args.db_path or str( + STORAGE_FOLDER / "calibration" / "policy_data.db" + ) output_path = args.output or str( STORAGE_FOLDER / "calibration" / "calibration_weights.npy" ) @@ -1188,11 +1205,15 @@ def main(argv=None): domain_variables = None if args.domain_variables: - domain_variables = [x.strip() for x in args.domain_variables.split(",")] + domain_variables = [ + x.strip() for x in args.domain_variables.split(",") + ] hierarchical_domains = None if args.hierarchical_domains: - hierarchical_domains = [x.strip() for x in args.hierarchical_domains.split(",")] + hierarchical_domains = [ + x.strip() for x in args.hierarchical_domains.split(",") + ] t_start = time.time() diff --git a/policyengine_us_data/calibration/unified_matrix_builder.py b/policyengine_us_data/calibration/unified_matrix_builder.py index 04d785ff..de80d015 100644 --- a/policyengine_us_data/calibration/unified_matrix_builder.py +++ b/policyengine_us_data/calibration/unified_matrix_builder.py @@ -124,7 +124,9 @@ def _compute_single_state( if rerandomize_takeup: for spec in SIMPLE_TAKEUP_VARS: entity = spec["entity"] - n_ent = len(state_sim.calculate(f"{entity}_id", map_to=entity).values) + n_ent = len( + state_sim.calculate(f"{entity}_id", map_to=entity).values + ) state_sim.set_input( spec["variable"], time_period, @@ -250,7 +252,9 @@ def _compute_single_state_group_counties( if rerandomize_takeup: for spec in SIMPLE_TAKEUP_VARS: entity = spec["entity"] - n_ent = len(state_sim.calculate(f"{entity}_id", map_to=entity).values) + n_ent = len( + state_sim.calculate(f"{entity}_id", map_to=entity).values + ) state_sim.set_input( spec["variable"], time_period, @@ -323,7 +327,9 @@ def _assemble_clone_values_standalone( state_masks = {int(s): clone_states == s for s in unique_clone_states} unique_person_states = np.unique(person_states) - person_state_masks = {int(s): person_states == s for s in unique_person_states} + person_state_masks = { + int(s): person_states == s for s in unique_person_states + } county_masks = {} unique_counties = None if clone_counties is not None and county_values: @@ -740,10 +746,18 @@ def _build_entity_relationship(self, sim) -> pd.DataFrame: self._entity_rel_cache = pd.DataFrame( { - "person_id": sim.calculate("person_id", map_to="person").values, - "household_id": sim.calculate("household_id", map_to="person").values, - "tax_unit_id": sim.calculate("tax_unit_id", map_to="person").values, - "spm_unit_id": sim.calculate("spm_unit_id", map_to="person").values, + "person_id": sim.calculate( + "person_id", map_to="person" + ).values, + "household_id": sim.calculate( + "household_id", map_to="person" + ).values, + "tax_unit_id": sim.calculate( + "tax_unit_id", map_to="person" + ).values, + "spm_unit_id": sim.calculate( + "spm_unit_id", map_to="person" + ).values, } ) return self._entity_rel_cache @@ -863,7 +877,9 @@ def _build_state_values( except Exception as exc: for f in futures: f.cancel() - raise RuntimeError(f"State {st} failed: {exc}") from exc + raise RuntimeError( + f"State {st} failed: {exc}" + ) from exc else: from policyengine_us import Microsimulation from policyengine_us_data.utils.takeup import ( @@ -919,7 +935,9 @@ def _build_state_values( for spec in SIMPLE_TAKEUP_VARS: entity = spec["entity"] n_ent = len( - state_sim.calculate(f"{entity}_id", map_to=entity).values + state_sim.calculate( + f"{entity}_id", map_to=entity + ).values ) state_sim.set_input( spec["variable"], @@ -1102,7 +1120,9 @@ def _build_county_values( except Exception as exc: for f in futures: f.cancel() - raise RuntimeError(f"State group {sf} failed: {exc}") from exc + raise RuntimeError( + f"State group {sf} failed: {exc}" + ) from exc else: from policyengine_us import Microsimulation from policyengine_us_data.utils.takeup import ( @@ -1278,7 +1298,9 @@ def _assemble_clone_values( # Pre-compute masks to avoid recomputing per variable state_masks = {int(s): clone_states == s for s in unique_clone_states} unique_person_states = np.unique(person_states) - person_state_masks = {int(s): person_states == s for s in unique_person_states} + person_state_masks = { + int(s): person_states == s for s in unique_person_states + } county_masks = {} unique_counties = None if clone_counties is not None and county_values: @@ -1291,7 +1313,9 @@ def _assemble_clone_values( continue if var in cdv and county_values and clone_counties is not None: first_county = unique_counties[0] - if var not in county_values.get(first_county, {}).get("hh", {}): + if var not in county_values.get(first_county, {}).get( + "hh", {} + ): continue arr = np.empty(n_records, dtype=np.float32) for county in unique_counties: @@ -1433,7 +1457,9 @@ def _calculate_uprating_factors(self, params) -> dict: factors[(from_year, "cpi")] = 1.0 try: - pop_from = params.calibration.gov.census.populations.total(from_year) + pop_from = params.calibration.gov.census.populations.total( + from_year + ) pop_to = params.calibration.gov.census.populations.total( self.time_period ) @@ -1510,7 +1536,9 @@ def _get_state_uprating_factors( var_factors[var] = 1.0 continue period = row.iloc[0]["period"] - factor, _ = self._get_uprating_info(var, period, national_factors) + factor, _ = self._get_uprating_info( + var, period, national_factors + ) var_factors[var] = factor result[state_int] = var_factors @@ -1645,7 +1673,9 @@ def _make_target_name( non_geo = [c for c in constraints if c["variable"] not in _GEO_VARS] if non_geo: - strs = [f"{c['variable']}{c['operation']}{c['value']}" for c in non_geo] + strs = [ + f"{c['variable']}{c['operation']}{c['value']}" for c in non_geo + ] parts.append("[" + ",".join(strs) + "]") return "/".join(parts) @@ -1789,9 +1819,15 @@ def build_matrix( n_targets = len(targets_df) # 2. Sort targets by geographic level - targets_df["_geo_level"] = targets_df["geographic_id"].apply(get_geo_level) - targets_df = targets_df.sort_values(["_geo_level", "variable", "geographic_id"]) - targets_df = targets_df.drop(columns=["_geo_level"]).reset_index(drop=True) + targets_df["_geo_level"] = targets_df["geographic_id"].apply( + get_geo_level + ) + targets_df = targets_df.sort_values( + ["_geo_level", "variable", "geographic_id"] + ) + targets_df = targets_df.drop(columns=["_geo_level"]).reset_index( + drop=True + ) # 3. Build column index structures from geography state_col_lists: Dict[int, list] = defaultdict(list) @@ -1818,7 +1854,9 @@ def build_matrix( geo_id = row["geographic_id"] target_geo_info.append((geo_level, geo_id)) - non_geo = [c for c in constraints if c["variable"] not in _GEO_VARS] + non_geo = [ + c for c in constraints if c["variable"] not in _GEO_VARS + ] non_geo_constraints_list.append(non_geo) target_names.append( @@ -1857,10 +1895,14 @@ def build_matrix( # 5c. State-independent structures (computed once) entity_rel = self._build_entity_relationship(sim) - household_ids = sim.calculate("household_id", map_to="household").values + household_ids = sim.calculate( + "household_id", map_to="household" + ).values person_hh_ids = sim.calculate("household_id", map_to="person").values hh_id_to_idx = {int(hid): idx for idx, hid in enumerate(household_ids)} - person_hh_indices = np.array([hh_id_to_idx[int(hid)] for hid in person_hh_ids]) + person_hh_indices = np.array( + [hh_id_to_idx[int(hid)] for hid in person_hh_ids] + ) tax_benefit_system = sim.tax_benefit_system # Pre-extract entity keys so workers don't need @@ -1868,7 +1910,9 @@ def build_matrix( variable_entity_map: Dict[str, str] = {} for var in unique_variables: if var.endswith("_count") and var in tax_benefit_system.variables: - variable_entity_map[var] = tax_benefit_system.variables[var].entity.key + variable_entity_map[var] = tax_benefit_system.variables[ + var + ].entity.key # 5c-extra: Entity-to-household index maps for takeup affected_target_info = {} @@ -1883,7 +1927,9 @@ def build_matrix( # Build entity-to-household index arrays spm_to_hh_id = ( - entity_rel.groupby("spm_unit_id")["household_id"].first().to_dict() + entity_rel.groupby("spm_unit_id")["household_id"] + .first() + .to_dict() ) spm_ids = sim.calculate("spm_unit_id", map_to="spm_unit").values spm_hh_idx = np.array( @@ -1891,7 +1937,9 @@ def build_matrix( ) tu_to_hh_id = ( - entity_rel.groupby("tax_unit_id")["household_id"].first().to_dict() + entity_rel.groupby("tax_unit_id")["household_id"] + .first() + .to_dict() ) tu_ids = sim.calculate("tax_unit_id", map_to="tax_unit").values tu_hh_idx = np.array( @@ -1910,7 +1958,9 @@ def build_matrix( f"{entity_level}_id", map_to=entity_level, ).values - ent_id_to_idx = {int(eid): idx for idx, eid in enumerate(ent_ids)} + ent_id_to_idx = { + int(eid): idx for idx, eid in enumerate(ent_ids) + } person_ent_ids = entity_rel[f"{entity_level}_id"].values entity_to_person_idx[entity_level] = np.array( [ent_id_to_idx[int(eid)] for eid in person_ent_ids] @@ -1933,7 +1983,9 @@ def build_matrix( for tvar, info in affected_target_info.items(): rk = info["rate_key"] if rk not in precomputed_rates: - precomputed_rates[rk] = load_take_up_rate(rk, self.time_period) + precomputed_rates[rk] = load_take_up_rate( + rk, self.time_period + ) # Store for post-optimization stacked takeup self.entity_hh_idx_map = entity_hh_idx_map @@ -2034,7 +2086,9 @@ def build_matrix( except Exception as exc: for f in futures: f.cancel() - raise RuntimeError(f"Clone {ci} failed: {exc}") from exc + raise RuntimeError( + f"Clone {ci} failed: {exc}" + ) from exc else: # ---- Sequential clone processing (unchanged) ---- @@ -2101,7 +2155,9 @@ def build_matrix( ent_counties = clone_counties[ent_hh] for cfips in np.unique(ent_counties): m = ent_counties == cfips - cv = county_values.get(cfips, {}).get("entity", {}) + cv = county_values.get(cfips, {}).get( + "entity", {} + ) if tvar in cv: ent_eligible[m] = cv[tvar][m] else: @@ -2126,7 +2182,9 @@ def build_matrix( ent_hh_ids, ) - ent_values = (ent_eligible * ent_takeup).astype(np.float32) + ent_values = (ent_eligible * ent_takeup).astype( + np.float32 + ) hh_result = np.zeros(n_records, dtype=np.float32) np.add.at(hh_result, ent_hh, ent_values) @@ -2186,15 +2244,17 @@ def build_matrix( constraint_key, ) if vkey not in count_cache: - count_cache[vkey] = _calculate_target_values_standalone( - target_variable=variable, - non_geo_constraints=non_geo, - n_households=n_records, - hh_vars=hh_vars, - person_vars=person_vars, - entity_rel=entity_rel, - household_ids=household_ids, - variable_entity_map=variable_entity_map, + count_cache[vkey] = ( + _calculate_target_values_standalone( + target_variable=variable, + non_geo_constraints=non_geo, + n_households=n_records, + hh_vars=hh_vars, + person_vars=person_vars, + entity_rel=entity_rel, + household_ids=household_ids, + variable_entity_map=variable_entity_map, + ) ) values = count_cache[vkey] else: diff --git a/policyengine_us_data/calibration/validate_national_h5.py b/policyengine_us_data/calibration/validate_national_h5.py index cbe22796..ba303812 100644 --- a/policyengine_us_data/calibration/validate_national_h5.py +++ b/policyengine_us_data/calibration/validate_national_h5.py @@ -145,9 +145,7 @@ def main(argv=None): icon = ( "PASS" if r["status"] == "PASS" - else "FAIL" - if r["status"] == "FAIL" - else "WARN" + else "FAIL" if r["status"] == "FAIL" else "WARN" ) print(f" [{icon}] {r['check']}: {r['detail']}") diff --git a/policyengine_us_data/calibration/validate_package.py b/policyengine_us_data/calibration/validate_package.py index c8ed16bc..4321fbf8 100644 --- a/policyengine_us_data/calibration/validate_package.py +++ b/policyengine_us_data/calibration/validate_package.py @@ -85,7 +85,9 @@ def validate_package( ) k = min(n_hardest, len(ratios)) hardest_local_idx = np.argpartition(ratios, k)[:k] - hardest_local_idx = hardest_local_idx[np.argsort(ratios[hardest_local_idx])] + hardest_local_idx = hardest_local_idx[ + np.argsort(ratios[hardest_local_idx]) + ] hardest_global_idx = achievable_idx[hardest_local_idx] hardest_targets = pd.DataFrame( @@ -94,7 +96,9 @@ def validate_package( "domain_variable": targets_df["domain_variable"] .iloc[hardest_global_idx] .values, - "variable": targets_df["variable"].iloc[hardest_global_idx].values, + "variable": targets_df["variable"] + .iloc[hardest_global_idx] + .values, "geographic_id": targets_df["geographic_id"] .iloc[hardest_global_idx] .values, @@ -186,7 +190,9 @@ def format_report(result: ValidationResult, package_path: str = None) -> str: lines.append(", ".join(parts)) lines.append("") - pct = 100 * result.n_achievable / result.n_targets if result.n_targets else 0 + pct = ( + 100 * result.n_achievable / result.n_targets if result.n_targets else 0 + ) pct_imp = 100 - pct lines.append("--- Achievability ---") lines.append( @@ -200,7 +206,9 @@ def format_report(result: ValidationResult, package_path: str = None) -> str: if len(result.impossible_targets) > 0: lines.append("--- Impossible Targets ---") for _, row in result.impossible_targets.iterrows(): - lines.append(f" {row['target_name']:<60s} {row['target_value']:>14,.0f}") + lines.append( + f" {row['target_name']:<60s} {row['target_value']:>14,.0f}" + ) lines.append("") if len(result.impossible_by_group) > 1: @@ -257,7 +265,9 @@ def format_report(result: ValidationResult, package_path: str = None) -> str: f" targets below ratio {result.strict_ratio})" ) elif result.n_impossible > 0: - lines.append(f"RESULT: FAIL ({result.n_impossible} impossible targets)") + lines.append( + f"RESULT: FAIL ({result.n_impossible} impossible targets)" + ) else: lines.append("RESULT: PASS") @@ -265,7 +275,9 @@ def format_report(result: ValidationResult, package_path: str = None) -> str: def main(): - parser = argparse.ArgumentParser(description="Validate a calibration package") + parser = argparse.ArgumentParser( + description="Validate a calibration package" + ) parser.add_argument( "path", nargs="?", diff --git a/policyengine_us_data/calibration/validate_staging.py b/policyengine_us_data/calibration/validate_staging.py index be2f908d..4ecea143 100644 --- a/policyengine_us_data/calibration/validate_staging.py +++ b/policyengine_us_data/calibration/validate_staging.py @@ -178,7 +178,9 @@ def _batch_stratum_constraints(engine, stratum_ids) -> dict: df = pd.read_sql(query, conn) result = {} for sid, group in df.groupby("stratum_id"): - result[int(sid)] = group[["variable", "operation", "value"]].to_dict("records") + result[int(sid)] = group[["variable", "operation", "value"]].to_dict( + "records" + ) for sid in stratum_ids: result.setdefault(int(sid), []) return result @@ -262,9 +264,15 @@ def _build_entity_rel(sim) -> pd.DataFrame: return pd.DataFrame( { "person_id": sim.calculate("person_id", map_to="person").values, - "household_id": sim.calculate("household_id", map_to="person").values, - "tax_unit_id": sim.calculate("tax_unit_id", map_to="person").values, - "spm_unit_id": sim.calculate("spm_unit_id", map_to="person").values, + "household_id": sim.calculate( + "household_id", map_to="person" + ).values, + "tax_unit_id": sim.calculate( + "tax_unit_id", map_to="person" + ).values, + "spm_unit_id": sim.calculate( + "spm_unit_id", map_to="person" + ).values, } ) @@ -704,7 +712,9 @@ def _run_state_via_districts( variable = row_data["variable"] stratum_id = int(row_data["stratum_id"]) constraints = constraints_map.get(stratum_id, []) - target_name = UnifiedMatrixBuilder._make_target_name(variable, constraints) + target_name = UnifiedMatrixBuilder._make_target_name( + variable, constraints + ) per_district_rows.append( { @@ -735,7 +745,9 @@ def _run_state_via_districts( stratum_id = int(row_data["stratum_id"]) constraints = constraints_map.get(stratum_id, []) - target_name = UnifiedMatrixBuilder._make_target_name(variable, constraints) + target_name = UnifiedMatrixBuilder._make_target_name( + variable, constraints + ) error = sim_value - target_value abs_error = abs(error) @@ -746,7 +758,9 @@ def _run_state_via_districts( rel_error = float("inf") if error != 0 else 0.0 rel_abs_error = float("inf") if abs_error != 0 else 0.0 - sanity_check, sanity_reason = _run_sanity_check(sim_value, variable, "state") + sanity_check, sanity_reason = _run_sanity_check( + sim_value, variable, "state" + ) summary_rows.append( { diff --git a/policyengine_us_data/datasets/acs/acs.py b/policyengine_us_data/datasets/acs/acs.py index 11d1ef73..0ecd3ee7 100644 --- a/policyengine_us_data/datasets/acs/acs.py +++ b/policyengine_us_data/datasets/acs/acs.py @@ -18,7 +18,9 @@ def generate(self) -> None: raw_data = self.census_acs(require=True).load() acs = h5py.File(self.file_path, mode="w") - person, household = [raw_data[entity] for entity in ("person", "household")] + person, household = [ + raw_data[entity] for entity in ("person", "household") + ] self.add_id_variables(acs, person, household) self.add_person_variables(acs, person, household) @@ -37,7 +39,9 @@ def add_id_variables( h_id_to_number = pd.Series( np.arange(len(household)), index=household["SERIALNO"] ) - household["household_id"] = h_id_to_number[household["SERIALNO"]].values + household["household_id"] = h_id_to_number[ + household["SERIALNO"] + ].values person["household_id"] = h_id_to_number[person["SERIALNO"]].values person["person_id"] = person.index + 1 @@ -96,7 +100,9 @@ def add_spm_variables(acs: h5py.File, spm_unit: DataFrame) -> None: @staticmethod def add_household_variables(acs: h5py.File, household: DataFrame) -> None: acs["household_vehicles_owned"] = household.VEH - acs["state_fips"] = acs["household_state_fips"] = household.ST.astype(int) + acs["state_fips"] = acs["household_state_fips"] = household.ST.astype( + int + ) class ACS_2022(ACS): diff --git a/policyengine_us_data/datasets/acs/census_acs.py b/policyengine_us_data/datasets/acs/census_acs.py index 7bd28bd6..842af627 100644 --- a/policyengine_us_data/datasets/acs/census_acs.py +++ b/policyengine_us_data/datasets/acs/census_acs.py @@ -66,7 +66,9 @@ def generate(self) -> None: household = self.process_household_data( household_url, "psam_hus", HOUSEHOLD_COLUMNS ) - person = self.process_person_data(person_url, "psam_pus", PERSON_COLUMNS) + person = self.process_person_data( + person_url, "psam_pus", PERSON_COLUMNS + ) person = person[person.SERIALNO.isin(household.SERIALNO)] household = household[household.SERIALNO.isin(person.SERIALNO)] storage["household"] = household @@ -104,7 +106,9 @@ def process_household_data( return res @staticmethod - def process_person_data(url: str, prefix: str, columns: List[str]) -> pd.DataFrame: + def process_person_data( + url: str, prefix: str, columns: List[str] + ) -> pd.DataFrame: req = requests.get(url, stream=True) with BytesIO() as f: pbar = tqdm() @@ -133,7 +137,9 @@ def process_person_data(url: str, prefix: str, columns: List[str]) -> pd.DataFra return res @staticmethod - def create_spm_unit_table(storage: pd.HDFStore, person: pd.DataFrame) -> None: + def create_spm_unit_table( + storage: pd.HDFStore, person: pd.DataFrame + ) -> None: SPM_UNIT_COLUMNS = [ "CAPHOUSESUB", "CAPWKCCXPNS", @@ -175,10 +181,12 @@ def create_spm_unit_table(storage: pd.HDFStore, person: pd.DataFrame) -> None: # Ensure SERIALNO is treated as string JOIN_COLUMNS = ["SERIALNO", "SPORDER"] - original_person_table["SERIALNO"] = original_person_table["SERIALNO"].astype( - str - ) - original_person_table["SPORDER"] = original_person_table["SPORDER"].astype(int) + original_person_table["SERIALNO"] = original_person_table[ + "SERIALNO" + ].astype(str) + original_person_table["SPORDER"] = original_person_table[ + "SPORDER" + ].astype(int) person["SERIALNO"] = person["SERIALNO"].astype(str) person["SPORDER"] = person["SPORDER"].astype(int) diff --git a/policyengine_us_data/datasets/cps/census_cps.py b/policyengine_us_data/datasets/cps/census_cps.py index 042fefe5..00ca020e 100644 --- a/policyengine_us_data/datasets/cps/census_cps.py +++ b/policyengine_us_data/datasets/cps/census_cps.py @@ -15,7 +15,9 @@ class CensusCPS(Dataset): def generate(self): if self._cps_download_url is None: - raise ValueError(f"No raw CPS data URL known for year {self.time_period}.") + raise ValueError( + f"No raw CPS data URL known for year {self.time_period}." + ) url = self._cps_download_url @@ -26,7 +28,9 @@ def generate(self): ] response = requests.get(url, stream=True) - total_size_in_bytes = int(response.headers.get("content-length", 200e6)) + total_size_in_bytes = int( + response.headers.get("content-length", 200e6) + ) progress_bar = tqdm( total=total_size_in_bytes, unit="iB", @@ -34,7 +38,9 @@ def generate(self): desc="Downloading ASEC", ) if response.status_code == 404: - raise FileNotFoundError("Received a 404 response when fetching the data.") + raise FileNotFoundError( + "Received a 404 response when fetching the data." + ) with BytesIO() as file: content_length_actual = 0 for data in response.iter_content(int(1e6)): @@ -59,23 +65,33 @@ def generate(self): file_prefix = "cpspb/asec/prod/data/2019/" else: file_prefix = "" - with zipfile.open(f"{file_prefix}pppub{file_year_code}.csv") as f: + with zipfile.open( + f"{file_prefix}pppub{file_year_code}.csv" + ) as f: storage["person"] = pd.read_csv( f, - usecols=PERSON_COLUMNS + spm_unit_columns + TAX_UNIT_COLUMNS, + usecols=PERSON_COLUMNS + + spm_unit_columns + + TAX_UNIT_COLUMNS, ).fillna(0) person = storage["person"] - with zipfile.open(f"{file_prefix}ffpub{file_year_code}.csv") as f: + with zipfile.open( + f"{file_prefix}ffpub{file_year_code}.csv" + ) as f: person_family_id = person.PH_SEQ * 10 + person.PF_SEQ family = pd.read_csv(f).fillna(0) family_id = family.FH_SEQ * 10 + family.FFPOS family = family[family_id.isin(person_family_id)] storage["family"] = family - with zipfile.open(f"{file_prefix}hhpub{file_year_code}.csv") as f: + with zipfile.open( + f"{file_prefix}hhpub{file_year_code}.csv" + ) as f: person_household_id = person.PH_SEQ household = pd.read_csv(f).fillna(0) household_id = household.H_SEQ - household = household[household_id.isin(person_household_id)] + household = household[ + household_id.isin(person_household_id) + ] storage["household"] = household storage["tax_unit"] = self._create_tax_unit_table(person) storage["spm_unit"] = self._create_spm_unit_table( diff --git a/policyengine_us_data/datasets/cps/cps.py b/policyengine_us_data/datasets/cps/cps.py index 418d7396..3ec1f769 100644 --- a/policyengine_us_data/datasets/cps/cps.py +++ b/policyengine_us_data/datasets/cps/cps.py @@ -93,7 +93,9 @@ def downsample(self, frac: float): # Store original dtypes before modifying original_data: dict = self.load_dataset() - original_dtypes = {key: original_data[key].dtype for key in original_data} + original_dtypes = { + key: original_data[key].dtype for key in original_data + } sim = Microsimulation(dataset=self) sim.subsample(frac=frac) @@ -206,13 +208,18 @@ def add_takeup(self): aca_rate = load_take_up_rate("aca", self.time_period) medicaid_rates_by_state = load_take_up_rate("medicaid", self.time_period) head_start_rate = load_take_up_rate("head_start", self.time_period) - early_head_start_rate = load_take_up_rate("early_head_start", self.time_period) + early_head_start_rate = load_take_up_rate( + "early_head_start", self.time_period + ) ssi_rate = load_take_up_rate("ssi", self.time_period) # EITC: varies by number of children eitc_child_count = baseline.calculate("eitc_child_count").values eitc_takeup_rate = np.array( - [eitc_rates_by_children.get(min(int(c), 3), 0.85) for c in eitc_child_count] + [ + eitc_rates_by_children.get(min(int(c), 3), 0.85) + for c in eitc_child_count + ] ) rng = seeded_rng("takes_up_eitc") data["takes_up_eitc"] = rng.random(n_tax_units) < eitc_takeup_rate @@ -231,7 +238,9 @@ def add_takeup(self): target_snap_takeup_count = int(snap_rate * n_spm_units) remaining_snap_needed = max(0, target_snap_takeup_count - n_snap_reporters) snap_non_reporter_rate = ( - remaining_snap_needed / n_snap_non_reporters if n_snap_non_reporters > 0 else 0 + remaining_snap_needed / n_snap_non_reporters + if n_snap_non_reporters > 0 + else 0 ) # Assign: all reporters + adjusted rate for non-reporters @@ -248,7 +257,9 @@ def add_takeup(self): hh_ids = data["household_id"] person_hh_ids = data["person_household_id"] hh_to_state = dict(zip(hh_ids, state_codes)) - person_states = np.array([hh_to_state.get(hh_id, "CA") for hh_id in person_hh_ids]) + person_states = np.array( + [hh_to_state.get(hh_id, "CA") for hh_id in person_hh_ids] + ) medicaid_rate_by_person = np.array( [medicaid_rates_by_state.get(s, 0.93) for s in person_states] ) @@ -259,7 +270,9 @@ def add_takeup(self): # Head Start rng = seeded_rng("takes_up_head_start_if_eligible") - data["takes_up_head_start_if_eligible"] = rng.random(n_persons) < head_start_rate + data["takes_up_head_start_if_eligible"] = ( + rng.random(n_persons) < head_start_rate + ) # Early Head Start rng = seeded_rng("takes_up_early_head_start_if_eligible") @@ -277,7 +290,9 @@ def add_takeup(self): target_ssi_takeup_count = int(ssi_rate * n_persons) remaining_ssi_needed = max(0, target_ssi_takeup_count - n_ssi_reporters) ssi_non_reporter_rate = ( - remaining_ssi_needed / n_ssi_non_reporters if n_ssi_non_reporters > 0 else 0 + remaining_ssi_needed / n_ssi_non_reporters + if n_ssi_non_reporters > 0 + else 0 ) # Assign: all reporters + adjusted rate for non-reporters @@ -300,7 +315,9 @@ def add_takeup(self): data["would_claim_wic"] = rng.random(n_persons) < wic_takeup_rate_by_person # WIC nutritional risk — fully resolved - wic_risk_rates = load_take_up_rate("wic_nutritional_risk", self.time_period) + wic_risk_rates = load_take_up_rate( + "wic_nutritional_risk", self.time_period + ) wic_risk_rate_by_person = np.array( [wic_risk_rates.get(c, 0) for c in wic_categories] ) @@ -347,8 +364,12 @@ def uprate_cps_data(data, from_period, to_period): uprating = create_policyengine_uprating_factors_table() for variable in uprating.index.unique(): if variable in data: - current_index = uprating[uprating.index == variable][to_period].values[0] - start_index = uprating[uprating.index == variable][from_period].values[0] + current_index = uprating[uprating.index == variable][ + to_period + ].values[0] + start_index = uprating[uprating.index == variable][ + from_period + ].values[0] growth = current_index / start_index data[variable] = data[variable] * growth @@ -390,7 +411,9 @@ def add_id_variables( # Marital units - marital_unit_id = person.PH_SEQ * 1e6 + np.maximum(person.A_LINENO, person.A_SPOUSE) + marital_unit_id = person.PH_SEQ * 1e6 + np.maximum( + person.A_LINENO, person.A_SPOUSE + ) # marital_unit_id is not the household ID, zero padded and followed # by the index within household (of each person, or their spouse if @@ -430,7 +453,9 @@ def add_personal_variables(cps: h5py.File, person: DataFrame) -> None: # "Is...blind or does...have serious difficulty seeing even when Wearing # glasses?" 1 -> Yes cps["is_blind"] = person.PEDISEYE == 1 - DISABILITY_FLAGS = ["PEDIS" + i for i in ["DRS", "EAR", "EYE", "OUT", "PHY", "REM"]] + DISABILITY_FLAGS = [ + "PEDIS" + i for i in ["DRS", "EAR", "EYE", "OUT", "PHY", "REM"] + ] cps["is_disabled"] = (person[DISABILITY_FLAGS] == 1).any(axis=1) def children_per_parent(col: str) -> pd.DataFrame: @@ -452,7 +477,9 @@ def children_per_parent(col: str) -> pd.DataFrame: # Aggregate to parent. res = ( - pd.concat([children_per_parent("PEPAR1"), children_per_parent("PEPAR2")]) + pd.concat( + [children_per_parent("PEPAR1"), children_per_parent("PEPAR2")] + ) .groupby(["PH_SEQ", "A_LINENO"]) .children.sum() .reset_index() @@ -478,7 +505,9 @@ def children_per_parent(col: str) -> pd.DataFrame: add_overtime_occupation(cps, person) -def add_personal_income_variables(cps: h5py.File, person: DataFrame, year: int): +def add_personal_income_variables( + cps: h5py.File, person: DataFrame, year: int +): """Add income variables. Args: @@ -504,14 +533,16 @@ def add_personal_income_variables(cps: h5py.File, person: DataFrame, year: int): cps["weekly_hours_worked"] = person.HRSWK cps["hours_worked_last_week"] = person.A_HRS1 - cps["taxable_interest_income"] = person.INT_VAL * (p["taxable_interest_fraction"]) + cps["taxable_interest_income"] = person.INT_VAL * ( + p["taxable_interest_fraction"] + ) cps["tax_exempt_interest_income"] = person.INT_VAL * ( 1 - p["taxable_interest_fraction"] ) cps["self_employment_income"] = person.SEMP_VAL cps["farm_income"] = person.FRSE_VAL - cps["qualified_dividend_income"] = ( - person.DIV_VAL * (p["qualified_dividend_fraction"]) + cps["qualified_dividend_income"] = person.DIV_VAL * ( + p["qualified_dividend_fraction"] ) cps["non_qualified_dividend_income"] = person.DIV_VAL * ( 1 - p["qualified_dividend_fraction"] @@ -530,14 +561,18 @@ def add_personal_income_variables(cps: h5py.File, person: DataFrame, year: int): # 8 = Other is_retirement = (person.RESNSS1 == 1) | (person.RESNSS2 == 1) is_disability = (person.RESNSS1 == 2) | (person.RESNSS2 == 2) - is_survivor = np.isin(person.RESNSS1, [3, 5]) | np.isin(person.RESNSS2, [3, 5]) + is_survivor = np.isin(person.RESNSS1, [3, 5]) | np.isin( + person.RESNSS2, [3, 5] + ) is_dependent = np.isin(person.RESNSS1, [4, 6, 7]) | np.isin( person.RESNSS2, [4, 6, 7] ) # Primary classification: assign full SS_VAL to the highest- # priority category when someone has multiple source codes. - cps["social_security_retirement"] = np.where(is_retirement, person.SS_VAL, 0) + cps["social_security_retirement"] = np.where( + is_retirement, person.SS_VAL, 0 + ) cps["social_security_disability"] = np.where( is_disability & ~is_retirement, person.SS_VAL, 0 ) @@ -580,7 +615,9 @@ def add_personal_income_variables(cps: h5py.File, person: DataFrame, year: int): # Add pensions and annuities. cps_pensions = person.PNSN_VAL + person.ANN_VAL # Assume a constant fraction of pension income is taxable. - cps["taxable_private_pension_income"] = cps_pensions * p["taxable_pension_fraction"] + cps["taxable_private_pension_income"] = ( + cps_pensions * p["taxable_pension_fraction"] + ) cps["tax_exempt_private_pension_income"] = cps_pensions * ( 1 - p["taxable_pension_fraction"] ) @@ -604,11 +641,18 @@ def add_personal_income_variables(cps: h5py.File, person: DataFrame, year: int): for source_with_taxable_fraction in ["401k", "403b", "sep"]: cps[f"taxable_{source_with_taxable_fraction}_distributions"] = ( cps[f"{source_with_taxable_fraction}_distributions"] - * p[f"taxable_{source_with_taxable_fraction}_distribution_fraction"] + * p[ + f"taxable_{source_with_taxable_fraction}_distribution_fraction" + ] ) cps[f"tax_exempt_{source_with_taxable_fraction}_distributions"] = cps[ f"{source_with_taxable_fraction}_distributions" - ] * (1 - p[f"taxable_{source_with_taxable_fraction}_distribution_fraction"]) + ] * ( + 1 + - p[ + f"taxable_{source_with_taxable_fraction}_distribution_fraction" + ] + ) del cps[f"{source_with_taxable_fraction}_distributions"] # Assume all regular IRA distributions are taxable, @@ -696,7 +740,9 @@ def add_personal_income_variables(cps: h5py.File, person: DataFrame, year: int): cps["traditional_ira_contributions"] = ira_capped * trad_ira_share cps["roth_ira_contributions"] = ira_capped * (1 - trad_ira_share) # Allocate capital gains into long-term and short-term based on aggregate split. - cps["long_term_capital_gains"] = person.CAP_VAL * (p["long_term_capgain_fraction"]) + cps["long_term_capital_gains"] = person.CAP_VAL * ( + p["long_term_capgain_fraction"] + ) cps["short_term_capital_gains"] = person.CAP_VAL * ( 1 - p["long_term_capgain_fraction"] ) @@ -724,7 +770,10 @@ def add_personal_income_variables(cps: h5py.File, person: DataFrame, year: int): # Get QBI simulation parameters --- yamlfilename = ( - files("policyengine_us_data") / "datasets" / "puf" / "qbi_assumptions.yaml" + files("policyengine_us_data") + / "datasets" + / "puf" + / "qbi_assumptions.yaml" ) with open(yamlfilename, "r", encoding="utf-8") as yamlfile: p = yaml.safe_load(yamlfile) @@ -778,10 +827,14 @@ def add_spm_variables(self, cps: h5py.File, spm_unit: DataFrame) -> None: 3: "RENTER", } cps["spm_unit_tenure_type"] = ( - spm_unit.SPM_TENMORTSTATUS.map(tenure_map).fillna("RENTER").astype("S") + spm_unit.SPM_TENMORTSTATUS.map(tenure_map) + .fillna("RENTER") + .astype("S") ) - cps["reduced_price_school_meals_reported"] = cps["free_school_meals_reported"] * 0 + cps["reduced_price_school_meals_reported"] = ( + cps["free_school_meals_reported"] * 0 + ) def add_household_variables(cps: h5py.File, household: DataFrame) -> None: @@ -915,7 +968,9 @@ def select_random_subset_to_target( share_to_move = min(share_to_move, 1.0) # Cap at 100% else: # Calculate how much to move to reach target (for EAD case) - needed_weighted = current_weighted - target_weighted # Will be negative + needed_weighted = ( + current_weighted - target_weighted + ) # Will be negative total_weight = np.sum(person_weights[eligible_ids]) share_to_move = abs(needed_weighted) / total_weight share_to_move = min(share_to_move, 1.0) # Cap at 100% @@ -1159,7 +1214,9 @@ def select_random_subset_to_target( ) # CONDITION 10: Government Employees - is_government_worker = np.isin(person.PEIO1COW, [1, 2, 3]) # Fed/state/local gov + is_government_worker = np.isin( + person.PEIO1COW, [1, 2, 3] + ) # Fed/state/local gov is_military_occupation = person.A_MJOCC == 11 # Military occupation is_government_employee = is_government_worker | is_military_occupation condition_10_mask = potentially_undocumented & is_government_employee @@ -1273,8 +1330,12 @@ def select_random_subset_to_target( undocumented_students_mask = ( (ssn_card_type == 0) & noncitizens & (person.A_HSCOL == 2) ) - undocumented_workers_count = np.sum(person_weights[undocumented_workers_mask]) - undocumented_students_count = np.sum(person_weights[undocumented_students_mask]) + undocumented_workers_count = np.sum( + person_weights[undocumented_workers_mask] + ) + undocumented_students_count = np.sum( + person_weights[undocumented_students_mask] + ) after_conditions_code_0 = np.sum(person_weights[ssn_card_type == 0]) print(f"After conditions - Code 0 people: {after_conditions_code_0:,.0f}") @@ -1469,11 +1530,15 @@ def select_random_subset_to_target( f"Selected {len(selected_indices)} people from {len(mixed_household_candidates)} candidates in mixed households" ) else: - print("No additional family members selected (target already reached)") + print( + "No additional family members selected (target already reached)" + ) else: print("No mixed-status households found for family correlation") else: - print("No additional undocumented people needed - target already reached") + print( + "No additional undocumented people needed - target already reached" + ) # Calculate the weighted impact code_0_after = np.sum(person_weights[ssn_card_type == 0]) @@ -1548,7 +1613,9 @@ def get_arrival_year_midpoint(peinusyr): age_at_entry = np.maximum(0, person.A_AGE - years_in_us) # start every non-citizen as LPR so no UNSET survives - immigration_status = np.full(len(person), "LEGAL_PERMANENT_RESIDENT", dtype="U32") + immigration_status = np.full( + len(person), "LEGAL_PERMANENT_RESIDENT", dtype="U32" + ) # Set citizens (SSN card type 1) to CITIZEN status immigration_status[ssn_card_type == 1] = "CITIZEN" @@ -1596,7 +1663,9 @@ def get_arrival_year_midpoint(peinusyr): immigration_status[recent_refugee_mask] = "REFUGEE" # 6. Temp non-qualified (Code 2 not caught by DACA rule) - mask = (ssn_card_type == 2) & (immigration_status == "LEGAL_PERMANENT_RESIDENT") + mask = (ssn_card_type == 2) & ( + immigration_status == "LEGAL_PERMANENT_RESIDENT" + ) immigration_status[mask] = "TPS" # Final write (all values now in ImmigrationStatus Enum) @@ -1612,7 +1681,9 @@ def get_arrival_year_midpoint(peinusyr): 2: "NON_CITIZEN_VALID_EAD", # Non-citizens with work/study authorization 3: "OTHER_NON_CITIZEN", # Non-citizens with indicators of legal status } - ssn_card_type_str = pd.Series(ssn_card_type).map(code_to_str).astype("S").values + ssn_card_type_str = ( + pd.Series(ssn_card_type).map(code_to_str).astype("S").values + ) cps["ssn_card_type"] = ssn_card_type_str # Final population summary @@ -1819,7 +1890,9 @@ def add_tips(self, cps: h5py.File): # Drop temporary columns used only for imputation # is_married is person-level here but policyengine-us defines it at Family # level, so we must not save it - cps = cps.drop(columns=["is_married", "is_under_18", "is_under_6"], errors="ignore") + cps = cps.drop( + columns=["is_married", "is_under_18", "is_under_6"], errors="ignore" + ) self.save_dataset(cps) @@ -1939,7 +2012,9 @@ def create_scf_reference_person_mask(cps_data, raw_person_data): all_persons_data["is_female"] = (raw_person_data.A_SEX == 2).values # Add marital status (A_MARITL codes: 1,2 = married with spouse present/absent) - all_persons_data["is_married"] = raw_person_data.A_MARITL.isin([1, 2]).values + all_persons_data["is_married"] = raw_person_data.A_MARITL.isin( + [1, 2] + ).values # Define adults as age 18+ all_persons_data["is_adult"] = all_persons_data["age"] >= 18 @@ -1958,7 +2033,8 @@ def create_scf_reference_person_mask(cps_data, raw_person_data): # Identify couple households (households with exactly 2 married adults) married_adults_per_household = ( all_persons_data[ - (all_persons_data["is_adult"]) & (all_persons_data["is_married"]) + (all_persons_data["is_adult"]) + & (all_persons_data["is_married"]) ] .groupby("person_household_id") .size() @@ -1966,7 +2042,12 @@ def create_scf_reference_person_mask(cps_data, raw_person_data): couple_households = married_adults_per_household[ (married_adults_per_household == 2) - & (all_persons_data.groupby("person_household_id")["n_adults"].first() == 2) + & ( + all_persons_data.groupby("person_household_id")[ + "n_adults" + ].first() + == 2 + ) ].index all_persons_data["is_couple_household"] = all_persons_data[ @@ -2066,7 +2147,9 @@ def determine_reference_person(group): } # Apply the mapping to recode the race values - cps_data["cps_race"] = np.vectorize(CPS_RACE_MAPPING.get)(cps_data["cps_race"]) + cps_data["cps_race"] = np.vectorize(CPS_RACE_MAPPING.get)( + cps_data["cps_race"] + ) lengths = {k: len(v) for k, v in cps_data.items()} var_len = cps_data["person_household_id"].shape[0] @@ -2098,7 +2181,9 @@ def determine_reference_person(group): # Add is_married variable for household heads based on raw person data reference_persons = person_data[mask] - receiver_data["is_married"] = reference_persons.A_MARITL.isin([1, 2]).values + receiver_data["is_married"] = reference_persons.A_MARITL.isin( + [1, 2] + ).values # Impute auto loan balance from the SCF from policyengine_us_data.datasets.scf.scf import SCF_2022 @@ -2133,7 +2218,9 @@ def determine_reference_person(group): logging.getLogger("microimpute").setLevel(getattr(logging, log_level)) qrf_model = QRF() - donor_data = donor_data.sample(frac=0.5, random_state=42).reset_index(drop=True) + donor_data = donor_data.sample(frac=0.5, random_state=42).reset_index( + drop=True + ) fitted_model = qrf_model.fit( X_train=donor_data, predictors=PREDICTORS, diff --git a/policyengine_us_data/datasets/cps/enhanced_cps.py b/policyengine_us_data/datasets/cps/enhanced_cps.py index 8755c73e..2b3b46ef 100644 --- a/policyengine_us_data/datasets/cps/enhanced_cps.py +++ b/policyengine_us_data/datasets/cps/enhanced_cps.py @@ -46,7 +46,9 @@ def reweight( normalisation_factor = np.where( is_national, nation_normalisation_factor, state_normalisation_factor ) - normalisation_factor = torch.tensor(normalisation_factor, dtype=torch.float32) + normalisation_factor = torch.tensor( + normalisation_factor, dtype=torch.float32 + ) targets_array = torch.tensor(targets_array, dtype=torch.float32) inv_mean_normalisation = 1 / np.mean(normalisation_factor.numpy()) @@ -59,8 +61,12 @@ def loss(weights): estimate = weights @ loss_matrix if torch.isnan(estimate).any(): raise ValueError("Estimate contains NaNs") - rel_error = (((estimate - targets_array) + 1) / (targets_array + 1)) ** 2 - rel_error_normalized = inv_mean_normalisation * rel_error * normalisation_factor + rel_error = ( + ((estimate - targets_array) + 1) / (targets_array + 1) + ) ** 2 + rel_error_normalized = ( + inv_mean_normalisation * rel_error * normalisation_factor + ) if torch.isnan(rel_error_normalized).any(): raise ValueError("Relative error contains NaNs") return rel_error_normalized.mean() @@ -115,7 +121,9 @@ def loss(weights): start_loss = l.item() loss_rel_change = (l.item() - start_loss) / start_loss l.backward() - iterator.set_postfix({"loss": l.item(), "loss_rel_change": loss_rel_change}) + iterator.set_postfix( + {"loss": l.item(), "loss_rel_change": loss_rel_change} + ) optimizer.step() if log_path is not None: performance.to_csv(log_path, index=False) @@ -174,7 +182,9 @@ def generate(self): # Run the optimization procedure to get (close to) minimum loss weights for year in range(self.start_year, self.end_year + 1): - loss_matrix, targets_array = build_loss_matrix(self.input_dataset, year) + loss_matrix, targets_array = build_loss_matrix( + self.input_dataset, year + ) zero_mask = np.isclose(targets_array, 0.0, atol=0.1) bad_mask = loss_matrix.columns.isin(bad_targets) keep_mask_bool = ~(zero_mask | bad_mask) @@ -200,7 +210,9 @@ def generate(self): # Validate dense weights w = optimised_weights if np.any(np.isnan(w)): - raise ValueError(f"Year {year}: household_weight contains NaN values") + raise ValueError( + f"Year {year}: household_weight contains NaN values" + ) if np.any(w < 0): raise ValueError( f"Year {year}: household_weight contains negative values" @@ -241,8 +253,12 @@ def generate(self): 1, 0.1, len(original_weights) ) for year in [2024]: - loss_matrix, targets_array = build_loss_matrix(self.input_dataset, year) - optimised_weights = reweight(original_weights, loss_matrix, targets_array) + loss_matrix, targets_array = build_loss_matrix( + self.input_dataset, year + ) + optimised_weights = reweight( + original_weights, loss_matrix, targets_array + ) data["household_weight"] = optimised_weights self.save_dataset(data) diff --git a/policyengine_us_data/datasets/cps/long_term/check_calibrated_estimates_interactive.py b/policyengine_us_data/datasets/cps/long_term/check_calibrated_estimates_interactive.py index 5fe3e599..28bdfd3e 100644 --- a/policyengine_us_data/datasets/cps/long_term/check_calibrated_estimates_interactive.py +++ b/policyengine_us_data/datasets/cps/long_term/check_calibrated_estimates_interactive.py @@ -24,7 +24,8 @@ ## Taxable Payroll for Social Security taxible_estimate_b = ( sim.calculate("taxable_earnings_for_social_security").sum() / 1e9 - + sim.calculate("social_security_taxable_self_employment_income").sum() / 1e9 + + sim.calculate("social_security_taxable_self_employment_income").sum() + / 1e9 ) ### Trustees SingleYearTRTables_TR2025.xlsx, Tab VI.G6 (nominal dollars in billions) @@ -65,7 +66,8 @@ ## Taxable Payroll for Social Security taxible_estimate_b = ( sim.calculate("taxable_earnings_for_social_security").sum() / 1e9 - + sim.calculate("social_security_taxable_self_employment_income").sum() / 1e9 + + sim.calculate("social_security_taxable_self_employment_income").sum() + / 1e9 ) ### Trustees SingleYearTRTables_TR2025.xlsx, Tab VI.G6 (nominal dollars in billions) @@ -173,9 +175,9 @@ def create_h6_reform(): # The swapped rate error is 14x smaller and aligns with tax-cutting intent. # Tier 1 (Base): HI ONLY (35%) - reform_payload["gov.irs.social_security.taxability.rate.base.benefit_cap"][ - period - ] = 0.35 + reform_payload[ + "gov.irs.social_security.taxability.rate.base.benefit_cap" + ][period] = 0.35 reform_payload["gov.irs.social_security.taxability.rate.base.excess"][ period ] = 0.35 @@ -184,25 +186,25 @@ def create_h6_reform(): reform_payload[ "gov.irs.social_security.taxability.rate.additional.benefit_cap" ][period] = 0.85 - reform_payload["gov.irs.social_security.taxability.rate.additional.excess"][ - period - ] = 0.85 + reform_payload[ + "gov.irs.social_security.taxability.rate.additional.excess" + ][period] = 0.85 # --- SET THRESHOLDS (MIN/MAX SWAP) --- # Always put the smaller number in 'base' and larger in 'adjusted_base' # Single - reform_payload["gov.irs.social_security.taxability.threshold.base.main.SINGLE"][ - period - ] = min(oasdi_target_single, HI_SINGLE) + reform_payload[ + "gov.irs.social_security.taxability.threshold.base.main.SINGLE" + ][period] = min(oasdi_target_single, HI_SINGLE) reform_payload[ "gov.irs.social_security.taxability.threshold.adjusted_base.main.SINGLE" ][period] = max(oasdi_target_single, HI_SINGLE) # Joint - reform_payload["gov.irs.social_security.taxability.threshold.base.main.JOINT"][ - period - ] = min(oasdi_target_joint, HI_JOINT) + reform_payload[ + "gov.irs.social_security.taxability.threshold.base.main.JOINT" + ][period] = min(oasdi_target_joint, HI_JOINT) reform_payload[ "gov.irs.social_security.taxability.threshold.adjusted_base.main.JOINT" ][period] = max(oasdi_target_joint, HI_JOINT) @@ -226,12 +228,12 @@ def create_h6_reform(): # 1. Set Thresholds to "HI Only" mode # Base = $34k / $44k - reform_payload["gov.irs.social_security.taxability.threshold.base.main.SINGLE"][ - elim_period - ] = HI_SINGLE - reform_payload["gov.irs.social_security.taxability.threshold.base.main.JOINT"][ - elim_period - ] = HI_JOINT + reform_payload[ + "gov.irs.social_security.taxability.threshold.base.main.SINGLE" + ][elim_period] = HI_SINGLE + reform_payload[ + "gov.irs.social_security.taxability.threshold.base.main.JOINT" + ][elim_period] = HI_JOINT # Adjusted = Infinity (Disable the second tier effectively) reform_payload[ @@ -260,12 +262,12 @@ def create_h6_reform(): ] = 0.35 # Tier 2 (Disabled via threshold, but zero out for safety) - reform_payload["gov.irs.social_security.taxability.rate.additional.benefit_cap"][ - elim_period - ] = 0.35 - reform_payload["gov.irs.social_security.taxability.rate.additional.excess"][ - elim_period - ] = 0.35 + reform_payload[ + "gov.irs.social_security.taxability.rate.additional.benefit_cap" + ][elim_period] = 0.35 + reform_payload[ + "gov.irs.social_security.taxability.rate.additional.excess" + ][elim_period] = 0.35 return reform_payload @@ -296,17 +298,23 @@ def create_h6_reform(): print(f"revenue_impact (B): {revenue_impact / 1e9:.2f}") # Calculate taxable payroll -taxable_ss_earnings = baseline.calculate("taxable_earnings_for_social_security") +taxable_ss_earnings = baseline.calculate( + "taxable_earnings_for_social_security" +) taxable_self_employment = baseline.calculate( "social_security_taxable_self_employment_income" ) -total_taxable_payroll = taxable_ss_earnings.sum() + taxable_self_employment.sum() +total_taxable_payroll = ( + taxable_ss_earnings.sum() + taxable_self_employment.sum() +) # Calculate SS benefits ss_benefits = baseline.calculate("social_security") total_ss_benefits = ss_benefits.sum() -est_rev_as_pct_of_taxable_payroll = 100 * revenue_impact / total_taxable_payroll +est_rev_as_pct_of_taxable_payroll = ( + 100 * revenue_impact / total_taxable_payroll +) # From https://www.ssa.gov/oact/solvency/provisions/tables/table_run133.html: target_rev_as_pct_of_taxable_payroll = -1.12 diff --git a/policyengine_us_data/datasets/cps/long_term/extract_ssa_costs.py b/policyengine_us_data/datasets/cps/long_term/extract_ssa_costs.py index 492a9d69..5ada2db9 100644 --- a/policyengine_us_data/datasets/cps/long_term/extract_ssa_costs.py +++ b/policyengine_us_data/datasets/cps/long_term/extract_ssa_costs.py @@ -2,7 +2,9 @@ import numpy as np # Read the file -df = pd.read_excel("SingleYearTRTables_TR2025.xlsx", sheet_name="VI.G9", header=None) +df = pd.read_excel( + "SingleYearTRTables_TR2025.xlsx", sheet_name="VI.G9", header=None +) print("DataFrame shape:", df.shape) print("\nChecking data types around row 66-70:") diff --git a/policyengine_us_data/datasets/cps/long_term/projection_utils.py b/policyengine_us_data/datasets/cps/long_term/projection_utils.py index 8aee4f3b..d0af8533 100644 --- a/policyengine_us_data/datasets/cps/long_term/projection_utils.py +++ b/policyengine_us_data/datasets/cps/long_term/projection_utils.py @@ -27,7 +27,9 @@ def build_household_age_matrix(sim, n_ages=86): n_households = len(household_ids_unique) X = np.zeros((n_households, n_ages)) - hh_id_to_idx = {hh_id: idx for idx, hh_id in enumerate(household_ids_unique)} + hh_id_to_idx = { + hh_id: idx for idx, hh_id in enumerate(household_ids_unique) + } for person_idx in range(len(age_person)): age = int(age_person.values[person_idx]) @@ -65,7 +67,9 @@ def get_pseudo_input_variables(sim): return pseudo_inputs -def create_household_year_h5(year, household_weights, base_dataset_path, output_dir): +def create_household_year_h5( + year, household_weights, base_dataset_path, output_dir +): """ Create a year-specific .h5 file with calibrated household weights. @@ -189,7 +193,9 @@ def calculate_year_statistics( Returns: Dictionary with year statistics and calibrated weights """ - income_tax_hh = sim.calculate("income_tax", period=year, map_to="household") + income_tax_hh = sim.calculate( + "income_tax", period=year, map_to="household" + ) income_tax_baseline_total = income_tax_hh.sum() income_tax_values = income_tax_hh.values @@ -200,7 +206,9 @@ def calculate_year_statistics( ss_values = None ss_target = None if use_ss: - ss_hh = sim.calculate("social_security", period=year, map_to="household") + ss_hh = sim.calculate( + "social_security", period=year, map_to="household" + ) ss_baseline_total = ss_hh.sum() ss_values = ss_hh.values diff --git a/policyengine_us_data/datasets/cps/long_term/run_household_projection.py b/policyengine_us_data/datasets/cps/long_term/run_household_projection.py index 1413efe4..30d1857a 100644 --- a/policyengine_us_data/datasets/cps/long_term/run_household_projection.py +++ b/policyengine_us_data/datasets/cps/long_term/run_household_projection.py @@ -105,9 +105,9 @@ def create_h6_reform(): # The swapped rate error is 14x smaller and aligns with tax-cutting intent. # Tier 1 (Base): HI ONLY (35%) - reform_payload["gov.irs.social_security.taxability.rate.base.benefit_cap"][ - period - ] = 0.35 + reform_payload[ + "gov.irs.social_security.taxability.rate.base.benefit_cap" + ][period] = 0.35 reform_payload["gov.irs.social_security.taxability.rate.base.excess"][ period ] = 0.35 @@ -116,25 +116,25 @@ def create_h6_reform(): reform_payload[ "gov.irs.social_security.taxability.rate.additional.benefit_cap" ][period] = 0.85 - reform_payload["gov.irs.social_security.taxability.rate.additional.excess"][ - period - ] = 0.85 + reform_payload[ + "gov.irs.social_security.taxability.rate.additional.excess" + ][period] = 0.85 # --- SET THRESHOLDS (MIN/MAX SWAP) --- # Always put the smaller number in 'base' and larger in 'adjusted_base' # Single - reform_payload["gov.irs.social_security.taxability.threshold.base.main.SINGLE"][ - period - ] = min(oasdi_target_single, HI_SINGLE) + reform_payload[ + "gov.irs.social_security.taxability.threshold.base.main.SINGLE" + ][period] = min(oasdi_target_single, HI_SINGLE) reform_payload[ "gov.irs.social_security.taxability.threshold.adjusted_base.main.SINGLE" ][period] = max(oasdi_target_single, HI_SINGLE) # Joint - reform_payload["gov.irs.social_security.taxability.threshold.base.main.JOINT"][ - period - ] = min(oasdi_target_joint, HI_JOINT) + reform_payload[ + "gov.irs.social_security.taxability.threshold.base.main.JOINT" + ][period] = min(oasdi_target_joint, HI_JOINT) reform_payload[ "gov.irs.social_security.taxability.threshold.adjusted_base.main.JOINT" ][period] = max(oasdi_target_joint, HI_JOINT) @@ -158,12 +158,12 @@ def create_h6_reform(): # 1. Set Thresholds to "HI Only" mode # Base = $34k / $44k - reform_payload["gov.irs.social_security.taxability.threshold.base.main.SINGLE"][ - elim_period - ] = HI_SINGLE - reform_payload["gov.irs.social_security.taxability.threshold.base.main.JOINT"][ - elim_period - ] = HI_JOINT + reform_payload[ + "gov.irs.social_security.taxability.threshold.base.main.SINGLE" + ][elim_period] = HI_SINGLE + reform_payload[ + "gov.irs.social_security.taxability.threshold.base.main.JOINT" + ][elim_period] = HI_JOINT # Adjusted = Infinity (Disable the second tier effectively) reform_payload[ @@ -192,12 +192,12 @@ def create_h6_reform(): ] = 0.35 # Tier 2 (Disabled via threshold, but zero out for safety) - reform_payload["gov.irs.social_security.taxability.rate.additional.benefit_cap"][ - elim_period - ] = 0.35 - reform_payload["gov.irs.social_security.taxability.rate.additional.excess"][ - elim_period - ] = 0.35 + reform_payload[ + "gov.irs.social_security.taxability.rate.additional.benefit_cap" + ][elim_period] = 0.35 + reform_payload[ + "gov.irs.social_security.taxability.rate.additional.excess" + ][elim_period] = 0.35 # Create the Reform Object from policyengine_core.reforms import Reform @@ -242,14 +242,18 @@ def create_h6_reform(): if USE_PAYROLL: sys.argv.remove("--use-payroll") if not USE_GREG: - print("Warning: --use-payroll requires --greg, enabling GREG automatically") + print( + "Warning: --use-payroll requires --greg, enabling GREG automatically" + ) USE_GREG = True USE_H6_REFORM = "--use-h6-reform" in sys.argv if USE_H6_REFORM: sys.argv.remove("--use-h6-reform") if not USE_GREG: - print("Warning: --use-h6-reform requires --greg, enabling GREG automatically") + print( + "Warning: --use-h6-reform requires --greg, enabling GREG automatically" + ) USE_GREG = True from ssa_data import load_h6_income_rate_change @@ -257,7 +261,9 @@ def create_h6_reform(): if USE_TOB: sys.argv.remove("--use-tob") if not USE_GREG: - print("Warning: --use-tob requires --greg, enabling GREG automatically") + print( + "Warning: --use-tob requires --greg, enabling GREG automatically" + ) USE_GREG = True from ssa_data import load_oasdi_tob_projections, load_hi_tob_projections @@ -314,7 +320,9 @@ def create_h6_reform(): print("STEP 1: DEMOGRAPHIC PROJECTIONS") print("=" * 70) -target_matrix = load_ssa_age_projections(start_year=START_YEAR, end_year=END_YEAR) +target_matrix = load_ssa_age_projections( + start_year=START_YEAR, end_year=END_YEAR +) n_years = target_matrix.shape[1] n_ages = target_matrix.shape[0] @@ -382,7 +390,9 @@ def create_h6_reform(): sim = Microsimulation(dataset=BASE_DATASET_PATH) - income_tax_hh = sim.calculate("income_tax", period=year, map_to="household") + income_tax_hh = sim.calculate( + "income_tax", period=year, map_to="household" + ) income_tax_baseline_total = income_tax_hh.sum() income_tax_values = income_tax_hh.values @@ -395,7 +405,9 @@ def create_h6_reform(): ss_values = None ss_target = None if USE_SS: - ss_hh = sim.calculate("social_security", period=year, map_to="household") + ss_hh = sim.calculate( + "social_security", period=year, map_to="household" + ) ss_values = ss_hh.values ss_target = load_ssa_benefit_projections(year) if year in display_years: @@ -440,7 +452,9 @@ def create_h6_reform(): else: # Create and apply H6 reform h6_reform = create_h6_reform() - reform_sim = Microsimulation(dataset=BASE_DATASET_PATH, reform=h6_reform) + reform_sim = Microsimulation( + dataset=BASE_DATASET_PATH, reform=h6_reform + ) # Calculate reform income tax income_tax_reform_hh = reform_sim.calculate( @@ -458,7 +472,9 @@ def create_h6_reform(): # Debug output for key years if year in display_years: - h6_impact_baseline = np.sum(h6_income_values * baseline_weights) + h6_impact_baseline = np.sum( + h6_income_values * baseline_weights + ) print( f" [DEBUG {year}] H6 baseline revenue: ${h6_impact_baseline / 1e9:.3f}B, target: ${h6_revenue_target / 1e9:.3f}B" ) @@ -531,9 +547,13 @@ def create_h6_reform(): f"largest: {max_neg:,.0f}" ) else: - print(f" [DEBUG {year}] Negative weights: 0 (all weights non-negative)") + print( + f" [DEBUG {year}] Negative weights: 0 (all weights non-negative)" + ) - if year in display_years and (USE_SS or USE_PAYROLL or USE_H6_REFORM or USE_TOB): + if year in display_years and ( + USE_SS or USE_PAYROLL or USE_H6_REFORM or USE_TOB + ): if USE_SS: ss_achieved = np.sum(ss_values * w_new) print( @@ -547,7 +567,9 @@ def create_h6_reform(): if USE_H6_REFORM and h6_revenue_target is not None: h6_revenue_achieved = np.sum(h6_income_values * w_new) error_pct = ( - (h6_revenue_achieved - h6_revenue_target) / abs(h6_revenue_target) * 100 + (h6_revenue_achieved - h6_revenue_target) + / abs(h6_revenue_target) + * 100 if h6_revenue_target != 0 else 0 ) @@ -571,7 +593,9 @@ def create_h6_reform(): total_population[year_idx] = np.sum(y_target) if SAVE_H5: - h5_path = create_household_year_h5(year, w_new, BASE_DATASET_PATH, OUTPUT_DIR) + h5_path = create_household_year_h5( + year, w_new, BASE_DATASET_PATH, OUTPUT_DIR + ) if year in display_years: print(f" Saved {year}.h5") diff --git a/policyengine_us_data/datasets/cps/small_enhanced_cps.py b/policyengine_us_data/datasets/cps/small_enhanced_cps.py index a1508032..53607d03 100644 --- a/policyengine_us_data/datasets/cps/small_enhanced_cps.py +++ b/policyengine_us_data/datasets/cps/small_enhanced_cps.py @@ -35,7 +35,9 @@ def create_small_ecps(): data[variable] = {} for time_period in simulation.get_holder(variable).get_known_periods(): values = simulation.get_holder(variable).get_array(time_period) - if simulation.tax_benefit_system.variables.get(variable).value_type in ( + if simulation.tax_benefit_system.variables.get( + variable + ).value_type in ( Enum, str, ): @@ -112,7 +114,8 @@ def create_sparse_ecps(): for time_period in sim.get_holder(variable).get_known_periods(): values = sim.get_holder(variable).get_array(time_period) if ( - sim.tax_benefit_system.variables.get(variable).value_type in (Enum, str) + sim.tax_benefit_system.variables.get(variable).value_type + in (Enum, str) and variable != "county_fips" ): values = values.decode_to_str().astype("S") @@ -135,7 +138,9 @@ def create_sparse_ecps(): ] missing = [v for v in critical_vars if v not in data] if missing: - raise ValueError(f"create_sparse_ecps: missing critical variables: {missing}") + raise ValueError( + f"create_sparse_ecps: missing critical variables: {missing}" + ) logging.info(f"create_sparse_ecps: data dict has {len(data)} variables") output_path = STORAGE_FOLDER / "sparse_enhanced_cps_2024.h5" @@ -150,7 +155,9 @@ def create_sparse_ecps(): raise ValueError( f"create_sparse_ecps: output file only {file_size:,} bytes (expected > 1MB)" ) - logging.info(f"create_sparse_ecps: wrote {file_size / 1e6:.1f}MB to {output_path}") + logging.info( + f"create_sparse_ecps: wrote {file_size / 1e6:.1f}MB to {output_path}" + ) if __name__ == "__main__": diff --git a/policyengine_us_data/datasets/puf/irs_puf.py b/policyengine_us_data/datasets/puf/irs_puf.py index c357cd56..dd77890a 100644 --- a/policyengine_us_data/datasets/puf/irs_puf.py +++ b/policyengine_us_data/datasets/puf/irs_puf.py @@ -30,7 +30,9 @@ def generate(self): with pd.HDFStore(self.file_path, mode="w") as storage: storage.put("puf", pd.read_csv(puf_file_path)) - storage.put("puf_demographics", pd.read_csv(puf_demographics_file_path)) + storage.put( + "puf_demographics", pd.read_csv(puf_demographics_file_path) + ) class IRS_PUF_2015(IRS_PUF): diff --git a/policyengine_us_data/datasets/puf/puf.py b/policyengine_us_data/datasets/puf/puf.py index 040098c1..ae8cf4fe 100644 --- a/policyengine_us_data/datasets/puf/puf.py +++ b/policyengine_us_data/datasets/puf/puf.py @@ -109,10 +109,14 @@ def simulate_w2_and_ubia_from_puf(puf, *, seed=None, diagnostics=True): ) revenues = np.maximum(qbi, 0) / margins - logit = logit_params["intercept"] + logit_params["slope_per_dollar"] * revenues + logit = ( + logit_params["intercept"] + logit_params["slope_per_dollar"] * revenues + ) # Set p = 0 when simulated receipts == 0 (no revenue means no payroll) - pr_has_employees = np.where(revenues == 0.0, 0.0, 1.0 / (1.0 + np.exp(-logit))) + pr_has_employees = np.where( + revenues == 0.0, 0.0, 1.0 / (1.0 + np.exp(-logit)) + ) has_employees = rng.binomial(1, pr_has_employees) # Labor share simulation @@ -121,7 +125,8 @@ def simulate_w2_and_ubia_from_puf(puf, *, seed=None, diagnostics=True): labor_ratios = np.where( is_rental, rng.beta(rental_beta_a, rental_beta_b, qbi.size) * rental_scale, - rng.beta(non_rental_beta_a, non_rental_beta_b, qbi.size) * non_rental_scale, + rng.beta(non_rental_beta_a, non_rental_beta_b, qbi.size) + * non_rental_scale, ) w2_wages = revenues * labor_ratios * has_employees @@ -204,7 +209,9 @@ def impute_missing_demographics( .fillna(0) ) - puf_with_demographics = puf_with_demographics.sample(n=10_000, random_state=0) + puf_with_demographics = puf_with_demographics.sample( + n=10_000, random_state=0 + ) DEMOGRAPHIC_VARIABLES = [ "AGEDP1", @@ -404,7 +411,9 @@ def preprocess_puf(puf: pd.DataFrame) -> pd.DataFrame: - puf["E25920"].fillna(0) - puf["E25960"].fillna(0) ) != 0 - partnership_se = np.where(has_partnership, gross_se - schedule_c_f_income, 0) + partnership_se = np.where( + has_partnership, gross_se - schedule_c_f_income, 0 + ) puf["partnership_se_income"] = partnership_se # --- Qualified Business Income Deduction (QBID) simulation --- @@ -415,9 +424,9 @@ def preprocess_puf(puf: pd.DataFrame) -> pd.DataFrame: puf_qbi_sources_for_sstb = puf[QBI_PARAMS["sstb_prob_map_by_name"].keys()] largest_qbi_source_name = puf_qbi_sources_for_sstb.idxmax(axis=1) - pr_sstb = largest_qbi_source_name.map(QBI_PARAMS["sstb_prob_map_by_name"]).fillna( - 0.0 - ) + pr_sstb = largest_qbi_source_name.map( + QBI_PARAMS["sstb_prob_map_by_name"] + ).fillna(0.0) puf["business_is_sstb"] = np.random.binomial(n=1, p=pr_sstb) reit_params = QBI_PARAMS["reit_ptp_income_distribution"] @@ -544,9 +553,9 @@ def generate(self): current_index = uprating[uprating.Variable == variable][ self.time_period ].values[0] - start_index = uprating[uprating.Variable == variable][2021].values[ - 0 - ] + start_index = uprating[uprating.Variable == variable][ + 2021 + ].values[0] growth = current_index / start_index arrays[variable] = arrays[variable] * growth self.save_dataset(arrays) @@ -626,7 +635,9 @@ def generate(self): for group in groups_assumed_to_be_tax_unit_like: self.holder[f"{group}_id"] = self.holder["tax_unit_id"] - self.holder[f"person_{group}_id"] = self.holder["person_tax_unit_id"] + self.holder[f"person_{group}_id"] = self.holder[ + "person_tax_unit_id" + ] for key in self.holder: if key == "filing_status": @@ -678,7 +689,9 @@ def add_filer(self, row, tax_unit_id): # Assume all of the interest deduction is the filer's deductible mortgage interest - self.holder["deductible_mortgage_interest"].append(row["interest_deduction"]) + self.holder["deductible_mortgage_interest"].append( + row["interest_deduction"] + ) for key in self.available_financial_vars: if key == "deductible_mortgage_interest": diff --git a/policyengine_us_data/datasets/scf/fed_scf.py b/policyengine_us_data/datasets/scf/fed_scf.py index 8c0d8e8c..f67a2c07 100644 --- a/policyengine_us_data/datasets/scf/fed_scf.py +++ b/policyengine_us_data/datasets/scf/fed_scf.py @@ -32,12 +32,16 @@ def load(self): def generate(self): if self._scf_download_url is None: - raise ValueError(f"No raw SCF data URL known for year {self.time_period}.") + raise ValueError( + f"No raw SCF data URL known for year {self.time_period}." + ) url = self._scf_download_url response = requests.get(url, stream=True) - total_size_in_bytes = int(response.headers.get("content-length", 200e6)) + total_size_in_bytes = int( + response.headers.get("content-length", 200e6) + ) progress_bar = tqdm( total=total_size_in_bytes, unit="iB", @@ -45,7 +49,9 @@ def generate(self): desc="Downloading SCF", ) if response.status_code == 404: - raise FileNotFoundError("Received a 404 response when fetching the data.") + raise FileNotFoundError( + "Received a 404 response when fetching the data." + ) with BytesIO() as file: content_length_actual = 0 for data in response.iter_content(int(1e6)): @@ -59,7 +65,9 @@ def generate(self): zipfile = ZipFile(file) with pd.HDFStore(self.file_path, mode="w") as storage: # Find the Stata file, which should be the only .dta file in the zip - dta_files = [f for f in zipfile.namelist() if f.endswith(".dta")] + dta_files = [ + f for f in zipfile.namelist() if f.endswith(".dta") + ] if not dta_files: raise FileNotFoundError( "No .dta file found in the SCF zip archive." diff --git a/policyengine_us_data/datasets/scf/scf.py b/policyengine_us_data/datasets/scf/scf.py index 3f2f11a7..1567fbbb 100644 --- a/policyengine_us_data/datasets/scf/scf.py +++ b/policyengine_us_data/datasets/scf/scf.py @@ -55,7 +55,9 @@ def generate(self): try: scf[key] = np.array(scf[key]) except Exception as e: - print(f"Warning: Could not convert {key} to numpy array: {e}") + print( + f"Warning: Could not convert {key} to numpy array: {e}" + ) self.save_dataset(scf) @@ -108,7 +110,9 @@ def downsample(self, frac: float): # Store original dtypes before modifying original_data: dict = self.load_dataset() - original_dtypes = {key: original_data[key].dtype for key in original_data} + original_dtypes = { + key: original_data[key].dtype for key in original_data + } sim = Microsimulation(dataset=self) sim.subsample(frac=frac) @@ -185,13 +189,17 @@ def rename_columns_to_match_cps(scf: dict, raw_data: pd.DataFrame) -> None: 4: 4, # Asian 5: 7, # Other } - scf["cps_race"] = raw_data["racecl5"].map(race_map).fillna(6).astype(int).values + scf["cps_race"] = ( + raw_data["racecl5"].map(race_map).fillna(6).astype(int).values + ) # Hispanic indicator scf["is_hispanic"] = (raw_data["racecl5"] == 3).values # Children in household if "kids" in raw_data.columns: - scf["own_children_in_household"] = raw_data["kids"].fillna(0).astype(int).values + scf["own_children_in_household"] = ( + raw_data["kids"].fillna(0).astype(int).values + ) # Rent if "rent" in raw_data.columns: @@ -199,7 +207,9 @@ def rename_columns_to_match_cps(scf: dict, raw_data: pd.DataFrame) -> None: # Vehicle loan (auto loan) if "veh_inst" in raw_data.columns: - scf["total_vehicle_installments"] = raw_data["veh_inst"].fillna(0).values + scf["total_vehicle_installments"] = ( + raw_data["veh_inst"].fillna(0).values + ) # Marital status if "married" in raw_data.columns: @@ -259,7 +269,9 @@ def add_auto_loan_interest(scf: dict, year: int) -> None: logger.error( f"Network error downloading SCF data for year {year}: {str(e)}" ) - raise RuntimeError(f"Failed to download SCF data for year {year}") from e + raise RuntimeError( + f"Failed to download SCF data for year {year}" + ) from e # Process zip file try: @@ -270,7 +282,9 @@ def add_auto_loan_interest(scf: dict, year: int) -> None: dta_files = [f for f in z.namelist() if f.endswith(".dta")] if not dta_files: logger.error(f"No Stata files found in zip for year {year}") - raise ValueError(f"No Stata files found in zip for year {year}") + raise ValueError( + f"No Stata files found in zip for year {year}" + ) logger.info(f"Found Stata files: {dta_files}") @@ -284,14 +298,18 @@ def add_auto_loan_interest(scf: dict, year: int) -> None: ) logger.info(f"Read DataFrame with shape {df.shape}") except Exception as e: - logger.error(f"Error reading Stata file for year {year}: {str(e)}") + logger.error( + f"Error reading Stata file for year {year}: {str(e)}" + ) raise RuntimeError( f"Failed to process Stata file for year {year}" ) from e except zipfile.BadZipFile as e: logger.error(f"Bad zip file for year {year}: {str(e)}") - raise RuntimeError(f"Downloaded zip file is corrupt for year {year}") from e + raise RuntimeError( + f"Downloaded zip file is corrupt for year {year}" + ) from e # Process the interest data and add to final SCF dictionary auto_df = df[IDENTIFYER_COLUMNS + AUTO_LOAN_COLUMNS].copy() diff --git a/policyengine_us_data/datasets/sipp/sipp.py b/policyengine_us_data/datasets/sipp/sipp.py index d7708266..bf8b75dd 100644 --- a/policyengine_us_data/datasets/sipp/sipp.py +++ b/policyengine_us_data/datasets/sipp/sipp.py @@ -68,7 +68,8 @@ def train_tip_model(): ) # Sum tip columns (AJB*_TXAMT + TJB*_TXAMT) across all jobs. df["tip_income"] = ( - df[df.columns[df.columns.str.contains("TXAMT")]].fillna(0).sum(axis=1) * 12 + df[df.columns[df.columns.str.contains("TXAMT")]].fillna(0).sum(axis=1) + * 12 ) df["employment_income"] = df.TPTOTINC * 12 df["is_under_18"] = (df.TAGE < 18) & (df.MONTHCODE == 12) diff --git a/policyengine_us_data/db/create_database_tables.py b/policyengine_us_data/db/create_database_tables.py index d89bad31..be22fcbb 100644 --- a/policyengine_us_data/db/create_database_tables.py +++ b/policyengine_us_data/db/create_database_tables.py @@ -39,7 +39,9 @@ class Stratum(SQLModel, table=True): description="Unique identifier for the stratum.", ) definition_hash: str = Field( - sa_column_kwargs={"comment": "SHA-256 hash of the stratum's constraints."}, + sa_column_kwargs={ + "comment": "SHA-256 hash of the stratum's constraints." + }, max_length=64, ) parent_stratum_id: Optional[int] = Field( @@ -87,7 +89,9 @@ class StratumConstraint(SQLModel, table=True): primary_key=True, description="The comparison operator (==, !=, >, >=, <, <=).", ) - value: str = Field(description="The value for the constraint rule (e.g., '25').") + value: str = Field( + description="The value for the constraint rule (e.g., '25')." + ) notes: Optional[str] = Field( default=None, description="Optional notes about the constraint." ) @@ -113,7 +117,9 @@ class Target(SQLModel, table=True): variable: str = Field( description="A variable defined in policyengine-us (e.g., 'income_tax')." ) - period: int = Field(description="The time period for the data, typically a year.") + period: int = Field( + description="The time period for the data, typically a year." + ) stratum_id: int = Field(foreign_key="strata.stratum_id", index=True) reform_id: int = Field( default=0, @@ -150,13 +156,19 @@ def calculate_definition_hash(mapper, connection, target: Stratum): Calculate and set the definition_hash before saving a Stratum instance. """ constraints_history = get_history(target, "constraints_rel") - if not (constraints_history.has_changes() or target.definition_hash is None): + if not ( + constraints_history.has_changes() or target.definition_hash is None + ): return if not target.constraints_rel: # Handle cases with no constraints # Include parent_stratum_id to make hash unique per parent - parent_str = str(target.parent_stratum_id) if target.parent_stratum_id else "" - target.definition_hash = hashlib.sha256(parent_str.encode("utf-8")).hexdigest() + parent_str = ( + str(target.parent_stratum_id) if target.parent_stratum_id else "" + ) + target.definition_hash = hashlib.sha256( + parent_str.encode("utf-8") + ).hexdigest() return constraint_strings = [ @@ -166,7 +178,9 @@ def calculate_definition_hash(mapper, connection, target: Stratum): constraint_strings.sort() # Include parent_stratum_id in the hash to ensure uniqueness per parent - parent_str = str(target.parent_stratum_id) if target.parent_stratum_id else "" + parent_str = ( + str(target.parent_stratum_id) if target.parent_stratum_id else "" + ) fingerprint_text = parent_str + "\n" + "\n".join(constraint_strings) h = hashlib.sha256(fingerprint_text.encode("utf-8")) target.definition_hash = h.hexdigest() @@ -227,7 +241,10 @@ def _validate_geographic_consistency(parent_rows, child_constraints): ) # CD must belong to the parent state. - if "state_fips" in parent_dict and "congressional_district_geoid" in child_dict: + if ( + "state_fips" in parent_dict + and "congressional_district_geoid" in child_dict + ): parent_state = int(parent_dict["state_fips"]) child_cd = int(child_dict["congressional_district_geoid"]) cd_state = child_cd // 100 @@ -271,7 +288,8 @@ def validate_parent_child_constraints(mapper, connection, target: Stratum): return child_set = { - (c.constraint_variable, c.operation, c.value) for c in target.constraints_rel + (c.constraint_variable, c.operation, c.value) + for c in target.constraints_rel } for var, op, val in parent_rows: diff --git a/policyengine_us_data/db/create_initial_strata.py b/policyengine_us_data/db/create_initial_strata.py index 8f6f051c..aa656c9d 100644 --- a/policyengine_us_data/db/create_initial_strata.py +++ b/policyengine_us_data/db/create_initial_strata.py @@ -45,12 +45,16 @@ def fetch_congressional_districts(year): df = df[df["district_number"] >= 0].copy() # Filter out statewide summary records for multi-district states - df["n_districts"] = df.groupby("state_fips")["state_fips"].transform("count") + df["n_districts"] = df.groupby("state_fips")["state_fips"].transform( + "count" + ) df = df[(df["n_districts"] == 1) | (df["district_number"] > 0)].copy() df = df.drop(columns=["n_districts"]) df.loc[df["district_number"] == 0, "district_number"] = 1 - df["congressional_district_geoid"] = df["state_fips"] * 100 + df["district_number"] + df["congressional_district_geoid"] = ( + df["state_fips"] * 100 + df["district_number"] + ) df = df[ [ @@ -126,7 +130,9 @@ def main(): # Fetch congressional district data cd_df = fetch_congressional_districts(year) - DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + DATABASE_URL = ( + f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + ) engine = create_engine(DATABASE_URL) with Session(engine) as session: @@ -151,7 +157,9 @@ def main(): # Create state-level strata unique_states = cd_df["state_fips"].unique() for state_fips in sorted(unique_states): - state_name = STATE_NAMES.get(state_fips, f"State FIPS {state_fips}") + state_name = STATE_NAMES.get( + state_fips, f"State FIPS {state_fips}" + ) state_stratum = Stratum( parent_stratum_id=us_stratum_id, notes=state_name, diff --git a/policyengine_us_data/db/etl_age.py b/policyengine_us_data/db/etl_age.py index db5e54da..1a12f372 100644 --- a/policyengine_us_data/db/etl_age.py +++ b/policyengine_us_data/db/etl_age.py @@ -66,7 +66,9 @@ def transform_age_data(age_data, docs): # Filter out Puerto Rico's district and state records # 5001800US7298 = 118th Congress, 5001900US7298 = 119th Congress df_geos = df_data[ - ~df_data["ucgid_str"].isin(["5001800US7298", "5001900US7298", "0400000US72"]) + ~df_data["ucgid_str"].isin( + ["5001800US7298", "5001900US7298", "0400000US72"] + ) ].copy() df = df_geos[["ucgid_str"] + AGE_COLS] @@ -104,7 +106,9 @@ def load_age_data(df_long, geo, year): raise ValueError('geo must be one of "National", "State", "District"') # Prepare to load data ----------- - DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + DATABASE_URL = ( + f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + ) engine = create_engine(DATABASE_URL) with Session(engine) as session: diff --git a/policyengine_us_data/db/etl_irs_soi.py b/policyengine_us_data/db/etl_irs_soi.py index f2b17795..aa8122a5 100644 --- a/policyengine_us_data/db/etl_irs_soi.py +++ b/policyengine_us_data/db/etl_irs_soi.py @@ -104,7 +104,9 @@ def make_records( f"WARNING: A59664 values appear to be in thousands (max={max_value:,.0f})" ) print("The IRS may have fixed their data inconsistency.") - print("Please verify and remove the special case handling if confirmed.") + print( + "Please verify and remove the special case handling if confirmed." + ) # Don't apply the fix - data appears to already be in thousands else: # Convert from dollars to thousands to match other columns @@ -160,7 +162,9 @@ def convert_district_data( """Transforms data from pre- to post- 2020 census districts""" df = input_df.copy() old_districts_df = df[df["ucgid_str"].str.startswith("5001800US")].copy() - old_districts_df = old_districts_df.sort_values("ucgid_str").reset_index(drop=True) + old_districts_df = old_districts_df.sort_values("ucgid_str").reset_index( + drop=True + ) old_values = old_districts_df["target_value"].to_numpy() new_values = mapping_matrix.T @ old_values @@ -285,15 +289,19 @@ def transform_soi_data(raw_df): # State ------------------- # You've got agi_stub == 0 in here, which you want to use any time you don't want to # divide data by AGI classes (i.e., agi_stub) - state_df = raw_df.copy().loc[(raw_df.STATE != "US") & (raw_df.CONG_DISTRICT == 0)] - state_df["ucgid_str"] = "0400000US" + state_df["STATEFIPS"].astype(str).str.zfill(2) + state_df = raw_df.copy().loc[ + (raw_df.STATE != "US") & (raw_df.CONG_DISTRICT == 0) + ] + state_df["ucgid_str"] = "0400000US" + state_df["STATEFIPS"].astype( + str + ).str.zfill(2) # District ------------------ district_df = raw_df.copy().loc[(raw_df.CONG_DISTRICT > 0)] - max_cong_district_by_state = raw_df.groupby("STATE")["CONG_DISTRICT"].transform( - "max" - ) + max_cong_district_by_state = raw_df.groupby("STATE")[ + "CONG_DISTRICT" + ].transform("max") district_df = raw_df.copy().loc[ (raw_df["CONG_DISTRICT"] > 0) | (max_cong_district_by_state == 0) ] @@ -362,7 +370,9 @@ def transform_soi_data(raw_df): # Pre- to Post- 2020 Census redisticting mapping = get_district_mapping() converted = [ - convert_district_data(r, mapping["mapping_matrix"], mapping["new_codes"]) + convert_district_data( + r, mapping["mapping_matrix"], mapping["new_codes"] + ) for r in records ] @@ -372,7 +382,9 @@ def transform_soi_data(raw_df): def load_soi_data(long_dfs, year): """Load a list of databases into the db, critically dependent on order""" - DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + DATABASE_URL = ( + f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + ) engine = create_engine(DATABASE_URL) session = Session(engine) @@ -446,7 +458,9 @@ def load_soi_data(long_dfs, year): filer_strata["state"][state_fips] = state_filer_stratum.stratum_id # District filer strata - for district_geoid, district_geo_stratum_id in geo_strata["district"].items(): + for district_geoid, district_geo_stratum_id in geo_strata[ + "district" + ].items(): # Check if district filer stratum exists district_filer_stratum = ( session.query(Stratum) @@ -478,7 +492,9 @@ def load_soi_data(long_dfs, year): session.add(district_filer_stratum) session.flush() - filer_strata["district"][district_geoid] = district_filer_stratum.stratum_id + filer_strata["district"][ + district_geoid + ] = district_filer_stratum.stratum_id session.commit() @@ -509,7 +525,9 @@ def load_soi_data(long_dfs, year): ) ] elif geo_info["type"] == "state": - parent_stratum_id = filer_strata["state"][geo_info["state_fips"]] + parent_stratum_id = filer_strata["state"][ + geo_info["state_fips"] + ] note = f"State FIPS {geo_info['state_fips']} EITC received with {n_children} children (filers)" constraints = [ StratumConstraint( @@ -618,7 +636,9 @@ def load_soi_data(long_dfs, year): # Store lookup for later use if geo_info["type"] == "national": - eitc_stratum_lookup["national"][n_children] = new_stratum.stratum_id + eitc_stratum_lookup["national"][ + n_children + ] = new_stratum.stratum_id elif geo_info["type"] == "state": key = (geo_info["state_fips"], n_children) eitc_stratum_lookup["state"][key] = new_stratum.stratum_id @@ -632,7 +652,8 @@ def load_soi_data(long_dfs, year): first_agi_index = [ i for i in range(len(long_dfs)) - if long_dfs[i][["target_variable"]].values[0] == "adjusted_gross_income" + if long_dfs[i][["target_variable"]].values[0] + == "adjusted_gross_income" and long_dfs[i][["breakdown_variable"]].values[0] == "one" ][0] for j in range(8, first_agi_index, 2): @@ -655,13 +676,17 @@ def load_soi_data(long_dfs, year): parent_stratum_id = filer_strata["national"] geo_description = "National" elif geo_info["type"] == "state": - parent_stratum_id = filer_strata["state"][geo_info["state_fips"]] + parent_stratum_id = filer_strata["state"][ + geo_info["state_fips"] + ] geo_description = f"State {geo_info['state_fips']}" elif geo_info["type"] == "district": parent_stratum_id = filer_strata["district"][ geo_info["congressional_district_geoid"] ] - geo_description = f"CD {geo_info['congressional_district_geoid']}" + geo_description = ( + f"CD {geo_info['congressional_district_geoid']}" + ) # Create child stratum with constraint for this IRS variable # Note: This stratum will have the constraint that amount_variable > 0 @@ -716,7 +741,9 @@ def load_soi_data(long_dfs, year): StratumConstraint( constraint_variable="congressional_district_geoid", operation="==", - value=str(geo_info["congressional_district_geoid"]), + value=str( + geo_info["congressional_district_geoid"] + ), ) ) @@ -778,7 +805,9 @@ def load_soi_data(long_dfs, year): elif geo_info["type"] == "district": stratum = session.get( Stratum, - filer_strata["district"][geo_info["congressional_district_geoid"]], + filer_strata["district"][ + geo_info["congressional_district_geoid"] + ], ) # Check if target already exists @@ -793,7 +822,9 @@ def load_soi_data(long_dfs, year): ) if existing_target: - existing_target.value = agi_values.iloc[i][["target_value"]].values[0] + existing_target.value = agi_values.iloc[i][ + ["target_value"] + ].values[0] else: stratum.targets_rel.append( Target( @@ -870,7 +901,9 @@ def load_soi_data(long_dfs, year): person_count = agi_df.iloc[i][["target_value"]].values[0] if geo_info["type"] == "state": - parent_stratum_id = filer_strata["state"][geo_info["state_fips"]] + parent_stratum_id = filer_strata["state"][ + geo_info["state_fips"] + ] note = f"State FIPS {geo_info['state_fips']} filers, AGI >= {agi_income_lower}, AGI < {agi_income_upper}" constraints = [ StratumConstraint( @@ -967,9 +1000,9 @@ def load_soi_data(long_dfs, year): session.flush() if geo_info["type"] == "state": - agi_stratum_lookup["state"][geo_info["state_fips"]] = ( - new_stratum.stratum_id - ) + agi_stratum_lookup["state"][ + geo_info["state_fips"] + ] = new_stratum.stratum_id elif geo_info["type"] == "district": agi_stratum_lookup["district"][ geo_info["congressional_district_geoid"] diff --git a/policyengine_us_data/db/etl_medicaid.py b/policyengine_us_data/db/etl_medicaid.py index 2c467799..dfc19cdc 100644 --- a/policyengine_us_data/db/etl_medicaid.py +++ b/policyengine_us_data/db/etl_medicaid.py @@ -116,7 +116,9 @@ def transform_administrative_medicaid_data(state_admin_df, year): ].sort_values("Reporting Period", ascending=False) if not state_history.empty: - fallback_value = state_history.iloc[0]["Total Medicaid Enrollment"] + fallback_value = state_history.iloc[0][ + "Total Medicaid Enrollment" + ] fallback_period = state_history.iloc[0]["Reporting Period"] print( f" {state_abbrev}: Using {fallback_value:,.0f} from period {fallback_period}" @@ -151,7 +153,9 @@ def transform_survey_medicaid_data(cd_survey_df): def load_medicaid_data(long_state, long_cd, year): - DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + DATABASE_URL = ( + f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + ) engine = create_engine(DATABASE_URL) with Session(engine) as session: @@ -218,7 +222,9 @@ def load_medicaid_data(long_state, long_cd, year): ) session.add(new_stratum) session.flush() - medicaid_stratum_lookup["state"][state_fips] = new_stratum.stratum_id + medicaid_stratum_lookup["state"][ + state_fips + ] = new_stratum.stratum_id # District ------------------- if long_cd is None: diff --git a/policyengine_us_data/db/etl_national_targets.py b/policyengine_us_data/db/etl_national_targets.py index 0e87aa84..2b78b6d6 100644 --- a/policyengine_us_data/db/etl_national_targets.py +++ b/policyengine_us_data/db/etl_national_targets.py @@ -423,10 +423,14 @@ def transform_national_targets(raw_targets): # Note: income_tax_positive from CBO and eitc from Treasury need # filer constraint cbo_non_tax = [ - t for t in raw_targets["cbo_targets"] if t["variable"] != "income_tax_positive" + t + for t in raw_targets["cbo_targets"] + if t["variable"] != "income_tax_positive" ] cbo_tax = [ - t for t in raw_targets["cbo_targets"] if t["variable"] == "income_tax_positive" + t + for t in raw_targets["cbo_targets"] + if t["variable"] == "income_tax_positive" ] all_direct_targets = raw_targets["direct_sum_targets"] + cbo_non_tax @@ -439,10 +443,14 @@ def transform_national_targets(raw_targets): ) direct_df = ( - pd.DataFrame(all_direct_targets) if all_direct_targets else pd.DataFrame() + pd.DataFrame(all_direct_targets) + if all_direct_targets + else pd.DataFrame() ) tax_filer_df = ( - pd.DataFrame(all_tax_filer_targets) if all_tax_filer_targets else pd.DataFrame() + pd.DataFrame(all_tax_filer_targets) + if all_tax_filer_targets + else pd.DataFrame() ) # Conditional targets stay as list for special processing @@ -451,7 +459,9 @@ def transform_national_targets(raw_targets): return direct_df, tax_filer_df, conditional_targets -def load_national_targets(direct_targets_df, tax_filer_df, conditional_targets): +def load_national_targets( + direct_targets_df, tax_filer_df, conditional_targets +): """ Load national targets into the database. @@ -465,13 +475,17 @@ def load_national_targets(direct_targets_df, tax_filer_df, conditional_targets): List of conditional count targets requiring strata """ - DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + DATABASE_URL = ( + f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + ) engine = create_engine(DATABASE_URL) with Session(engine) as session: # Get the national stratum us_stratum = ( - session.query(Stratum).filter(Stratum.parent_stratum_id == None).first() + session.query(Stratum) + .filter(Stratum.parent_stratum_id == None) + .first() ) if not us_stratum: @@ -497,7 +511,9 @@ def load_national_targets(direct_targets_df, tax_filer_df, conditional_targets): notes_parts = [] if pd.notna(target_data.get("notes")): notes_parts.append(target_data["notes"]) - notes_parts.append(f"Source: {target_data.get('source', 'Unknown')}") + notes_parts.append( + f"Source: {target_data.get('source', 'Unknown')}" + ) combined_notes = " | ".join(notes_parts) if existing_target: @@ -567,7 +583,9 @@ def load_national_targets(direct_targets_df, tax_filer_df, conditional_targets): notes_parts = [] if pd.notna(target_data.get("notes")): notes_parts.append(target_data["notes"]) - notes_parts.append(f"Source: {target_data.get('source', 'Unknown')}") + notes_parts.append( + f"Source: {target_data.get('source', 'Unknown')}" + ) combined_notes = " | ".join(notes_parts) if existing_target: @@ -681,17 +699,23 @@ def load_national_targets(direct_targets_df, tax_filer_df, conditional_targets): ] session.add(new_stratum) - print(f"Created stratum and target for {constraint_var} enrollment") + print( + f"Created stratum and target for {constraint_var} enrollment" + ) session.commit() total_targets = ( - len(direct_targets_df) + len(tax_filer_df) + len(conditional_targets) + len(direct_targets_df) + + len(tax_filer_df) + + len(conditional_targets) ) print(f"\nSuccessfully loaded {total_targets} national targets") print(f" - {len(direct_targets_df)} direct sum targets") print(f" - {len(tax_filer_df)} tax filer targets") - print(f" - {len(conditional_targets)} enrollment count targets (as strata)") + print( + f" - {len(conditional_targets)} enrollment count targets (as strata)" + ) def main(): @@ -706,8 +730,8 @@ def main(): # Transform print("Transforming targets...") - direct_targets_df, tax_filer_df, conditional_targets = transform_national_targets( - raw_targets + direct_targets_df, tax_filer_df, conditional_targets = ( + transform_national_targets(raw_targets) ) # Load diff --git a/policyengine_us_data/db/etl_pregnancy.py b/policyengine_us_data/db/etl_pregnancy.py index e8756cfb..c237d262 100644 --- a/policyengine_us_data/db/etl_pregnancy.py +++ b/policyengine_us_data/db/etl_pregnancy.py @@ -219,7 +219,9 @@ def transform_pregnancy_data( df = births_df.merge(pop_df, on="state_abbrev") df["state_fips"] = df["state_abbrev"].map(STATE_ABBREV_TO_FIPS) # Point-in-time pregnancy count. - df["pregnancy_target"] = (df["births"] * PREGNANCY_DURATION_FRACTION).round() + df["pregnancy_target"] = ( + df["births"] * PREGNANCY_DURATION_FRACTION + ).round() # Rate for stochastic assignment in the CPS build. df["pregnancy_rate"] = ( df["births"] / df["female_15_44"] @@ -266,7 +268,9 @@ def load_pregnancy_data( for _, row in df.iterrows(): state_fips = int(row["state_fips"]) if state_fips not in geo_strata["state"]: - logger.warning(f"No geographic stratum for FIPS {state_fips}, skipping") + logger.warning( + f"No geographic stratum for FIPS {state_fips}, skipping" + ) continue parent_id = geo_strata["state"][state_fips] @@ -358,7 +362,9 @@ def main(): except Exception as e: logger.warning(f"ACS {acs_year} not available: {e}") if pop_df is None: - raise RuntimeError(f"No ACS population data for {year - 1} or {year - 2}") + raise RuntimeError( + f"No ACS population data for {year - 1} or {year - 2}" + ) df = transform_pregnancy_data(births_df, pop_df) diff --git a/policyengine_us_data/db/etl_snap.py b/policyengine_us_data/db/etl_snap.py index dc5975a4..48cb7e77 100644 --- a/policyengine_us_data/db/etl_snap.py +++ b/policyengine_us_data/db/etl_snap.py @@ -154,7 +154,9 @@ def transform_survey_snap_data(raw_df): def load_administrative_snap_data(df_states, year): - DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + DATABASE_URL = ( + f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + ) engine = create_engine(DATABASE_URL) with Session(engine) as session: @@ -242,7 +244,9 @@ def load_survey_snap_data(survey_df, year, snap_stratum_lookup): load_administrative_snap_data, so we don't recreate them. """ - DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + DATABASE_URL = ( + f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + ) engine = create_engine(DATABASE_URL) with Session(engine) as session: diff --git a/policyengine_us_data/db/etl_state_income_tax.py b/policyengine_us_data/db/etl_state_income_tax.py index 95fbc285..a9ffa35c 100644 --- a/policyengine_us_data/db/etl_state_income_tax.py +++ b/policyengine_us_data/db/etl_state_income_tax.py @@ -320,7 +320,11 @@ def main(): # Print summary total_collections = transformed_df["income_tax_collections"].sum() states_with_tax = len( - [s for s in transformed_df["state_abbrev"] if s not in NO_INCOME_TAX_STATES] + [ + s + for s in transformed_df["state_abbrev"] + if s not in NO_INCOME_TAX_STATES + ] ) logger.info( @@ -333,7 +337,9 @@ def main(): # Print Ohio specifically (for the issue reference) ohio_row = transformed_df[transformed_df["state_abbrev"] == "OH"].iloc[0] - logger.info(f" Ohio (OH): ${ohio_row['income_tax_collections'] / 1e9:.2f}B") + logger.info( + f" Ohio (OH): ${ohio_row['income_tax_collections'] / 1e9:.2f}B" + ) if __name__ == "__main__": diff --git a/policyengine_us_data/db/validate_database.py b/policyengine_us_data/db/validate_database.py index b57a83c3..2fa819f2 100644 --- a/policyengine_us_data/db/validate_database.py +++ b/policyengine_us_data/db/validate_database.py @@ -9,7 +9,9 @@ import pandas as pd from policyengine_us.system import system -conn = sqlite3.connect("policyengine_us_data/storage/calibration/policy_data.db") +conn = sqlite3.connect( + "policyengine_us_data/storage/calibration/policy_data.db" +) stratum_constraints_df = pd.read_sql("SELECT * FROM stratum_constraints", conn) targets_df = pd.read_sql("SELECT * FROM targets", conn) diff --git a/policyengine_us_data/db/validate_hierarchy.py b/policyengine_us_data/db/validate_hierarchy.py index 1c555703..353c09ee 100644 --- a/policyengine_us_data/db/validate_hierarchy.py +++ b/policyengine_us_data/db/validate_hierarchy.py @@ -31,7 +31,9 @@ def validate_geographic_hierarchy(session): "ERROR: No US-level stratum found (should have parent_stratum_id = None)" ) else: - print(f"✓ US stratum found: {us_stratum.notes} (ID: {us_stratum.stratum_id})") + print( + f"✓ US stratum found: {us_stratum.notes} (ID: {us_stratum.stratum_id})" + ) # Check it has no constraints us_constraints = session.exec( @@ -87,10 +89,14 @@ def validate_geographic_hierarchy(session): c for c in constraints if c.constraint_variable == "state_fips" ] if not state_fips_constraint: - errors.append(f"ERROR: State '{state.notes}' has no state_fips constraint") + errors.append( + f"ERROR: State '{state.notes}' has no state_fips constraint" + ) else: state_ids[state.stratum_id] = state.notes - print(f" - {state.notes}: state_fips = {state_fips_constraint[0].value}") + print( + f" - {state.notes}: state_fips = {state_fips_constraint[0].value}" + ) # Check congressional districts print("\nChecking Congressional Districts...") @@ -106,10 +112,14 @@ def validate_geographic_hierarchy(session): ) ).all() constraint_vars = {c.constraint_variable for c in constraints} - if "congressional_district_geoid" in constraint_vars and constraint_vars <= { - "state_fips", - "congressional_district_geoid", - }: + if ( + "congressional_district_geoid" in constraint_vars + and constraint_vars + <= { + "state_fips", + "congressional_district_geoid", + } + ): all_cds.append(s) print(f"✓ Found {len(all_cds)} congressional/delegate districts") @@ -151,7 +161,9 @@ def validate_geographic_hierarchy(session): wyoming_cds.append(child) if len(wyoming_cds) != 1: - errors.append(f"ERROR: Wyoming should have 1 CD, found {len(wyoming_cds)}") + errors.append( + f"ERROR: Wyoming should have 1 CD, found {len(wyoming_cds)}" + ) else: print(f"✓ Wyoming has correct number of CDs: 1") @@ -175,7 +187,9 @@ def validate_geographic_hierarchy(session): for cd in wrong_parent_cds[:5]: errors.append(f" - {cd.notes}") else: - print("✓ No congressional districts incorrectly parented to Wyoming") + print( + "✓ No congressional districts incorrectly parented to Wyoming" + ) return errors @@ -226,7 +240,9 @@ def validate_demographic_strata(session): if actual == expected_total: print(f"✓ {domain}: {actual} strata") elif actual == 0: - errors.append(f"ERROR: {domain} has no strata, expected {expected_total}") + errors.append( + f"ERROR: {domain} has no strata, expected {expected_total}" + ) else: errors.append( f"WARNING: {domain} has {actual} strata, expected {expected_total}" @@ -306,12 +322,18 @@ def validate_constraint_uniqueness(session): else: hash_counts[stratum.definition_hash] = [stratum] - duplicates = {h: strata for h, strata in hash_counts.items() if len(strata) > 1} + duplicates = { + h: strata for h, strata in hash_counts.items() if len(strata) > 1 + } if duplicates: - errors.append(f"ERROR: Found {len(duplicates)} duplicate definition_hashes") + errors.append( + f"ERROR: Found {len(duplicates)} duplicate definition_hashes" + ) for hash_val, strata in list(duplicates.items())[:3]: # Show first 3 - errors.append(f" Hash {hash_val[:10]}... appears {len(strata)} times:") + errors.append( + f" Hash {hash_val[:10]}... appears {len(strata)} times:" + ) for s in strata[:3]: errors.append(f" - ID {s.stratum_id}: {s.notes[:50]}") else: @@ -323,7 +345,9 @@ def validate_constraint_uniqueness(session): def main(): """Run all validation checks""" - DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + DATABASE_URL = ( + f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + ) engine = create_engine(DATABASE_URL) all_errors = [] diff --git a/policyengine_us_data/geography/__init__.py b/policyengine_us_data/geography/__init__.py index f2006819..0bcc73f0 100644 --- a/policyengine_us_data/geography/__init__.py +++ b/policyengine_us_data/geography/__init__.py @@ -2,7 +2,9 @@ import pandas as pd import os -ZIP_CODE_DATASET_PATH = Path(__file__).parent.parent / "geography" / "zip_codes.csv.gz" +ZIP_CODE_DATASET_PATH = ( + Path(__file__).parent.parent / "geography" / "zip_codes.csv.gz" +) # Avoid circular import error when -us-data is initialized if os.path.exists(ZIP_CODE_DATASET_PATH): diff --git a/policyengine_us_data/geography/county_fips.py b/policyengine_us_data/geography/county_fips.py index 6bb2b9e9..3e5ac518 100644 --- a/policyengine_us_data/geography/county_fips.py +++ b/policyengine_us_data/geography/county_fips.py @@ -21,9 +21,7 @@ def generate_county_fips_2020_dataset(): # COUNTYFP - Three-digit county portion of FIPS (001 for Autauga County, AL, if STATEFP is 01) # COUNTYNAME - County name - COUNTY_FIPS_2020_URL = ( - "https://www2.census.gov/geo/docs/reference/codes2020/national_county2020.txt" - ) + COUNTY_FIPS_2020_URL = "https://www2.census.gov/geo/docs/reference/codes2020/national_county2020.txt" # Download the base tab-delimited data file response = requests.get(COUNTY_FIPS_2020_URL) @@ -70,7 +68,9 @@ def generate_county_fips_2020_dataset(): csv_buffer = BytesIO() # Save CSV into buffer object and reset pointer - county_fips.to_csv(csv_buffer, index=False, compression="gzip", encoding="utf-8") + county_fips.to_csv( + csv_buffer, index=False, compression="gzip", encoding="utf-8" + ) csv_buffer.seek(0) # Upload to Hugging Face diff --git a/policyengine_us_data/geography/create_zip_code_dataset.py b/policyengine_us_data/geography/create_zip_code_dataset.py index 981b5de5..eb154cf7 100644 --- a/policyengine_us_data/geography/create_zip_code_dataset.py +++ b/policyengine_us_data/geography/create_zip_code_dataset.py @@ -51,5 +51,7 @@ zcta.set_index("zcta").population[zip_code.zcta].values / zip_code.groupby("zcta").zip_code.count()[zip_code.zcta].values ) -zip_code["county"] = zcta_to_county.set_index("zcta").county[zip_code.zcta].values +zip_code["county"] = ( + zcta_to_county.set_index("zcta").county[zip_code.zcta].values +) zip_code.to_csv("zip_codes.csv", compression="gzip") diff --git a/policyengine_us_data/parameters/__init__.py b/policyengine_us_data/parameters/__init__.py index dc385f8e..2fcddb5a 100644 --- a/policyengine_us_data/parameters/__init__.py +++ b/policyengine_us_data/parameters/__init__.py @@ -65,6 +65,8 @@ def load_take_up_rate(variable_name: str, year: int = 2018): break if applicable_value is None: - raise ValueError(f"No take-up rate found for {variable_name} in {year}") + raise ValueError( + f"No take-up rate found for {variable_name} in {year}" + ) return applicable_value diff --git a/policyengine_us_data/storage/calibration_targets/audit_county_enum.py b/policyengine_us_data/storage/calibration_targets/audit_county_enum.py index fcaf443f..4849a10e 100644 --- a/policyengine_us_data/storage/calibration_targets/audit_county_enum.py +++ b/policyengine_us_data/storage/calibration_targets/audit_county_enum.py @@ -109,7 +109,9 @@ def print_categorized_report(invalid_entries, county_to_states): print("\n" + "=" * 60) print("WRONG STATE ASSIGNMENTS") print("=" * 60) - for name, wrong_state, correct_states in sorted(invalid_entries["wrong_state"]): + for name, wrong_state, correct_states in sorted( + invalid_entries["wrong_state"] + ): print(f" {name}") print(f" Listed as: {wrong_state}") print(f" Actually exists in: {', '.join(sorted(correct_states))}") diff --git a/policyengine_us_data/storage/calibration_targets/make_block_cd_distributions.py b/policyengine_us_data/storage/calibration_targets/make_block_cd_distributions.py index f2b634e0..6f55e3f7 100644 --- a/policyengine_us_data/storage/calibration_targets/make_block_cd_distributions.py +++ b/policyengine_us_data/storage/calibration_targets/make_block_cd_distributions.py @@ -78,7 +78,9 @@ def build_block_cd_distributions(): # Create CD geoid in our format: state_fips * 100 + district # Examples: AL-1 = 101, NY-10 = 3610, DC = 1198 - df["cd_geoid"] = df["state_fips"].astype(int) * 100 + df["CD119"].astype(int) + df["cd_geoid"] = df["state_fips"].astype(int) * 100 + df["CD119"].astype( + int + ) # Step 4: Calculate P(block|CD) print("\nCalculating block probabilities...") @@ -95,7 +97,9 @@ def build_block_cd_distributions(): output = df[["cd_geoid", "GEOID", "probability"]].rename( columns={"GEOID": "block_geoid"} ) - output = output.sort_values(["cd_geoid", "probability"], ascending=[True, False]) + output = output.sort_values( + ["cd_geoid", "probability"], ascending=[True, False] + ) # Step 6: Save as gzipped CSV (parquet requires pyarrow) output_path = STORAGE_FOLDER / "block_cd_distributions.csv.gz" diff --git a/policyengine_us_data/storage/calibration_targets/make_block_crosswalk.py b/policyengine_us_data/storage/calibration_targets/make_block_crosswalk.py index ed0d8cc1..418e725f 100644 --- a/policyengine_us_data/storage/calibration_targets/make_block_crosswalk.py +++ b/policyengine_us_data/storage/calibration_targets/make_block_crosswalk.py @@ -60,7 +60,9 @@ def download_state_baf(state_fips: str, state_abbr: str) -> dict: ) # Place (City/CDP) - place_file = f"BlockAssign_ST{state_fips}_{state_abbr}_INCPLACE_CDP.txt" + place_file = ( + f"BlockAssign_ST{state_fips}_{state_abbr}_INCPLACE_CDP.txt" + ) if place_file in z.namelist(): df = pd.read_csv(z.open(place_file), sep="|", dtype=str) results["place"] = df.rename( @@ -166,17 +168,23 @@ def build_block_crosswalk(): # Merge other geographies if "sldl" in bafs: - df = df.merge(bafs["sldl"], on="block_geoid", how="left") + df = df.merge( + bafs["sldl"], on="block_geoid", how="left" + ) else: df["sldl"] = None if "place" in bafs: - df = df.merge(bafs["place"], on="block_geoid", how="left") + df = df.merge( + bafs["place"], on="block_geoid", how="left" + ) else: df["place_fips"] = None if "vtd" in bafs: - df = df.merge(bafs["vtd"], on="block_geoid", how="left") + df = df.merge( + bafs["vtd"], on="block_geoid", how="left" + ) else: df["vtd"] = None diff --git a/policyengine_us_data/storage/calibration_targets/make_county_cd_distributions.py b/policyengine_us_data/storage/calibration_targets/make_county_cd_distributions.py index 2c91f1ca..ba68a556 100644 --- a/policyengine_us_data/storage/calibration_targets/make_county_cd_distributions.py +++ b/policyengine_us_data/storage/calibration_targets/make_county_cd_distributions.py @@ -126,11 +126,15 @@ def build_county_cd_distributions(): # Create CD geoid in our format: state_fips * 100 + district # Examples: AL-1 = 101, NY-10 = 3610, DC = 1198 - df["cd_geoid"] = df["state_fips"].astype(int) * 100 + df["CD119"].astype(int) + df["cd_geoid"] = df["state_fips"].astype(int) * 100 + df["CD119"].astype( + int + ) # Step 4: Aggregate by (CD, county) print("\nAggregating population by CD and county...") - cd_county_pop = df.groupby(["cd_geoid", "county_fips"])["POP20"].sum().reset_index() + cd_county_pop = ( + df.groupby(["cd_geoid", "county_fips"])["POP20"].sum().reset_index() + ) print(f" Unique CD-county pairs: {len(cd_county_pop):,}") # Step 5: Calculate P(county|CD) @@ -147,7 +151,9 @@ def build_county_cd_distributions(): # Step 6: Map county FIPS to enum names print("\nMapping county FIPS to enum names...") fips_to_enum = build_county_fips_to_enum_mapping() - cd_county_pop["county_name"] = cd_county_pop["county_fips"].map(fips_to_enum) + cd_county_pop["county_name"] = cd_county_pop["county_fips"].map( + fips_to_enum + ) # Check for unmapped counties unmapped = cd_county_pop[cd_county_pop["county_name"].isna()] @@ -171,7 +177,9 @@ def build_county_cd_distributions(): # Step 8: Save CSV output = cd_county_pop[["cd_geoid", "county_name", "probability"]] - output = output.sort_values(["cd_geoid", "probability"], ascending=[True, False]) + output = output.sort_values( + ["cd_geoid", "probability"], ascending=[True, False] + ) output_path = STORAGE_FOLDER / "county_cd_distributions.csv" output.to_csv(output_path, index=False) diff --git a/policyengine_us_data/storage/calibration_targets/make_district_mapping.py b/policyengine_us_data/storage/calibration_targets/make_district_mapping.py index bfb4936e..2b930a2d 100644 --- a/policyengine_us_data/storage/calibration_targets/make_district_mapping.py +++ b/policyengine_us_data/storage/calibration_targets/make_district_mapping.py @@ -91,7 +91,9 @@ def fetch_block_to_district_map(congress: int) -> pd.DataFrame: return bef[["GEOID", f"CD{congress}"]] else: - raise ValueError(f"Congress {congress} is not supported by this function.") + raise ValueError( + f"Congress {congress} is not supported by this function." + ) def fetch_block_population(state) -> pd.DataFrame: @@ -143,7 +145,9 @@ def fetch_block_population(state) -> pd.DataFrame: geo_df = pd.DataFrame(geo_records, columns=["LOGRECNO", "GEOID"]) # ---------------- P-file: pull total-population cell ---------------------- - p1_records = [(p[4], int(p[5])) for p in map(lambda x: x.split("|"), p1_lines)] + p1_records = [ + (p[4], int(p[5])) for p in map(lambda x: x.split("|"), p1_lines) + ] p1_df = pd.DataFrame(p1_records, columns=["LOGRECNO", "P0010001"]) # ---------------- Merge & finish ----------------------------------------- diff --git a/policyengine_us_data/storage/calibration_targets/pull_hardcoded_targets.py b/policyengine_us_data/storage/calibration_targets/pull_hardcoded_targets.py index 3199a56a..da8b5412 100644 --- a/policyengine_us_data/storage/calibration_targets/pull_hardcoded_targets.py +++ b/policyengine_us_data/storage/calibration_targets/pull_hardcoded_targets.py @@ -42,9 +42,13 @@ def pull_hardcoded_targets(): "VARIABLE": list(HARD_CODED_TOTALS.keys()), "VALUE": list(HARD_CODED_TOTALS.values()), "IS_COUNT": [0.0] - * len(HARD_CODED_TOTALS), # All values are monetary amounts, not counts + * len( + HARD_CODED_TOTALS + ), # All values are monetary amounts, not counts "BREAKDOWN_VARIABLE": [np.nan] - * len(HARD_CODED_TOTALS), # No breakdown variable for hardcoded targets + * len( + HARD_CODED_TOTALS + ), # No breakdown variable for hardcoded targets "LOWER_BOUND": [np.nan] * len(HARD_CODED_TOTALS), "UPPER_BOUND": [np.nan] * len(HARD_CODED_TOTALS), } diff --git a/policyengine_us_data/storage/calibration_targets/pull_snap_targets.py b/policyengine_us_data/storage/calibration_targets/pull_snap_targets.py index 202286e7..1830bdb3 100644 --- a/policyengine_us_data/storage/calibration_targets/pull_snap_targets.py +++ b/policyengine_us_data/storage/calibration_targets/pull_snap_targets.py @@ -84,9 +84,7 @@ def extract_usda_snap_data(year=2023): session.headers.update(headers) # Try to visit the main page first to get any necessary cookies - main_page = ( - "https://www.fns.usda.gov/pd/supplemental-nutrition-assistance-program-snap" - ) + main_page = "https://www.fns.usda.gov/pd/supplemental-nutrition-assistance-program-snap" try: session.get(main_page, timeout=30) except: @@ -169,7 +167,9 @@ def extract_usda_snap_data(year=2023): .reset_index(drop=True) ) df_states["GEO_ID"] = "0400000US" + df_states["STATE_FIPS"] - df_states["GEO_NAME"] = "state_" + df_states["State"].map(STATE_NAME_TO_ABBREV) + df_states["GEO_NAME"] = "state_" + df_states["State"].map( + STATE_NAME_TO_ABBREV + ) count_df = df_states[["GEO_ID", "GEO_NAME"]].copy() count_df["VALUE"] = df_states["Households"] diff --git a/policyengine_us_data/storage/calibration_targets/pull_soi_targets.py b/policyengine_us_data/storage/calibration_targets/pull_soi_targets.py index ce6d9f88..59050a1b 100644 --- a/policyengine_us_data/storage/calibration_targets/pull_soi_targets.py +++ b/policyengine_us_data/storage/calibration_targets/pull_soi_targets.py @@ -129,17 +129,26 @@ def pull_national_soi_variable( national_df: Optional[pd.DataFrame] = None, ) -> pd.DataFrame: """Download and save national AGI totals.""" - df = pd.read_excel("https://www.irs.gov/pub/irs-soi/22in54us.xlsx", skiprows=7) + df = pd.read_excel( + "https://www.irs.gov/pub/irs-soi/22in54us.xlsx", skiprows=7 + ) assert ( - np.abs(df.iloc[soi_variable_ident, 1] - df.iloc[soi_variable_ident, 2:12].sum()) + np.abs( + df.iloc[soi_variable_ident, 1] + - df.iloc[soi_variable_ident, 2:12].sum() + ) < 100 ), "Row 0 doesn't add up — check the file." agi_values = df.iloc[soi_variable_ident, 2:12].astype(int).to_numpy() - agi_values = np.concatenate([agi_values[:8], [agi_values[8] + agi_values[9]]]) + agi_values = np.concatenate( + [agi_values[:8], [agi_values[8] + agi_values[9]]] + ) - agi_brackets = [AGI_STUB_TO_BAND[i] for i in range(1, len(SOI_COLUMNS) + 1)] + agi_brackets = [ + AGI_STUB_TO_BAND[i] for i in range(1, len(SOI_COLUMNS) + 1) + ] result = pd.DataFrame( { @@ -152,7 +161,9 @@ def pull_national_soi_variable( ) # final column order - result = result[["GEO_ID", "GEO_NAME", "LOWER_BOUND", "UPPER_BOUND", "VALUE"]] + result = result[ + ["GEO_ID", "GEO_NAME", "LOWER_BOUND", "UPPER_BOUND", "VALUE"] + ] result["IS_COUNT"] = int(is_count) result["VARIABLE"] = variable_name @@ -175,7 +186,9 @@ def pull_state_soi_variable( state_df: Optional[pd.DataFrame] = None, ) -> pd.DataFrame: """Download and save state AGI totals.""" - df = pd.read_csv("https://www.irs.gov/pub/irs-soi/22in55cmcsv.csv", thousands=",") + df = pd.read_csv( + "https://www.irs.gov/pub/irs-soi/22in55cmcsv.csv", thousands="," + ) merged = ( df[df["AGI_STUB"].isin([9, 10])] @@ -198,11 +211,17 @@ def pull_state_soi_variable( ["GEO_ID", "GEO_NAME", "agi_bracket", soi_variable_ident], ].rename(columns={soi_variable_ident: "VALUE"}) - result["LOWER_BOUND"] = result["agi_bracket"].map(lambda b: AGI_BOUNDS[b][0]) - result["UPPER_BOUND"] = result["agi_bracket"].map(lambda b: AGI_BOUNDS[b][1]) + result["LOWER_BOUND"] = result["agi_bracket"].map( + lambda b: AGI_BOUNDS[b][0] + ) + result["UPPER_BOUND"] = result["agi_bracket"].map( + lambda b: AGI_BOUNDS[b][1] + ) # final column order - result = result[["GEO_ID", "GEO_NAME", "LOWER_BOUND", "UPPER_BOUND", "VALUE"]] + result = result[ + ["GEO_ID", "GEO_NAME", "LOWER_BOUND", "UPPER_BOUND", "VALUE"] + ] result["IS_COUNT"] = int(is_count) result["VARIABLE"] = variable_name @@ -230,7 +249,9 @@ def pull_district_soi_variable( df = df[df["agi_stub"] != 0] df["STATEFIPS"] = df["STATEFIPS"].astype(int).astype(str).str.zfill(2) - df["CONG_DISTRICT"] = df["CONG_DISTRICT"].astype(int).astype(str).str.zfill(2) + df["CONG_DISTRICT"] = ( + df["CONG_DISTRICT"].astype(int).astype(str).str.zfill(2) + ) if SOI_DISTRICT_TAX_YEAR >= 2024: raise RuntimeError( f"SOI tax year {SOI_DISTRICT_TAX_YEAR} may need " @@ -267,8 +288,12 @@ def pull_district_soi_variable( ] ].rename(columns={soi_variable_ident: "VALUE"}) - result["LOWER_BOUND"] = result["agi_bracket"].map(lambda b: AGI_BOUNDS[b][0]) - result["UPPER_BOUND"] = result["agi_bracket"].map(lambda b: AGI_BOUNDS[b][1]) + result["LOWER_BOUND"] = result["agi_bracket"].map( + lambda b: AGI_BOUNDS[b][0] + ) + result["UPPER_BOUND"] = result["agi_bracket"].map( + lambda b: AGI_BOUNDS[b][1] + ) # if redistrict: # result = apply_redistricting(result, variable_name) @@ -283,23 +308,25 @@ def pull_district_soi_variable( # Check that all GEO_IDs are valid produced_codes = set(result["GEO_ID"]) invalid_codes = produced_codes - valid_district_codes - assert not invalid_codes, ( - f"Invalid district codes after redistricting: {invalid_codes}" - ) + assert ( + not invalid_codes + ), f"Invalid district codes after redistricting: {invalid_codes}" # Check we have exactly 436 districts - assert len(produced_codes) == 436, ( - f"Expected 436 districts after redistricting, got {len(produced_codes)}" - ) + assert ( + len(produced_codes) == 436 + ), f"Expected 436 districts after redistricting, got {len(produced_codes)}" # Check that all GEO_IDs successfully mapped to names missing_names = result[result["GEO_NAME"].isna()]["GEO_ID"].unique() - assert len(missing_names) == 0, ( - f"GEO_IDs without names in ID_TO_NAME mapping: {missing_names}" - ) + assert ( + len(missing_names) == 0 + ), f"GEO_IDs without names in ID_TO_NAME mapping: {missing_names}" # final column order - result = result[["GEO_ID", "GEO_NAME", "LOWER_BOUND", "UPPER_BOUND", "VALUE"]] + result = result[ + ["GEO_ID", "GEO_NAME", "LOWER_BOUND", "UPPER_BOUND", "VALUE"] + ] result["IS_COUNT"] = int(is_count) result["VARIABLE"] = variable_name @@ -430,11 +457,15 @@ def combine_geography_levels(districts: Optional[bool] = False) -> None: ) # Get state totals indexed by STATEFIPS - state_totals = state.loc[state_mask].set_index("STATEFIPS")["VALUE"] + state_totals = state.loc[state_mask].set_index("STATEFIPS")[ + "VALUE" + ] # Get district totals grouped by STATEFIPS district_totals = ( - district.loc[district_mask].groupby("STATEFIPS")["VALUE"].sum() + district.loc[district_mask] + .groupby("STATEFIPS")["VALUE"] + .sum() ) # Check and rescale districts for each state @@ -449,8 +480,12 @@ def combine_geography_levels(districts: Optional[bool] = False) -> None: f"Districts' sum does not match {fips} state total for {variable}/{count_type} " f"in bracket [{lower}, {upper}]. Rescaling district targets." ) - rescale_mask = district_mask & (district["STATEFIPS"] == fips) - district.loc[rescale_mask, "VALUE"] *= s_total / d_total + rescale_mask = district_mask & ( + district["STATEFIPS"] == fips + ) + district.loc[rescale_mask, "VALUE"] *= ( + s_total / d_total + ) # Combine all data combined = pd.concat( diff --git a/policyengine_us_data/storage/upload_completed_datasets.py b/policyengine_us_data/storage/upload_completed_datasets.py index 7af0da04..d4f7a070 100644 --- a/policyengine_us_data/storage/upload_completed_datasets.py +++ b/policyengine_us_data/storage/upload_completed_datasets.py @@ -103,7 +103,9 @@ def _check_group_has_data(f, name): ) # At least one income group must have data - has_income = any(_check_group_has_data(f, g) for g in INCOME_GROUPS) + has_income = any( + _check_group_has_data(f, g) for g in INCOME_GROUPS + ) if not has_income: errors.append( f"No income data found. Need at least one of " @@ -125,7 +127,9 @@ def _check_group_has_data(f, name): try: dataset_cls = FILENAME_TO_DATASET.get(filename) if dataset_cls is None: - raise DatasetValidationError(f"No dataset class registered for {filename}") + raise DatasetValidationError( + f"No dataset class registered for {filename}" + ) sim = Microsimulation(dataset=dataset_cls) year = 2024 diff --git a/policyengine_us_data/tests/test_calibration/conftest.py b/policyengine_us_data/tests/test_calibration/conftest.py index 0698cef0..35449156 100644 --- a/policyengine_us_data/tests/test_calibration/conftest.py +++ b/policyengine_us_data/tests/test_calibration/conftest.py @@ -13,4 +13,6 @@ def db_uri(): @pytest.fixture(scope="module") def dataset_path(): - return str(STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5") + return str( + STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5" + ) diff --git a/policyengine_us_data/tests/test_calibration/create_test_fixture.py b/policyengine_us_data/tests/test_calibration/create_test_fixture.py index 2fadeeeb..00334734 100644 --- a/policyengine_us_data/tests/test_calibration/create_test_fixture.py +++ b/policyengine_us_data/tests/test_calibration/create_test_fixture.py @@ -30,7 +30,9 @@ def create_test_fixture(): # Household-level arrays household_ids = np.arange(N_HOUSEHOLDS, dtype=np.int32) - household_weights = np.random.uniform(500, 3000, N_HOUSEHOLDS).astype(np.float32) + household_weights = np.random.uniform(500, 3000, N_HOUSEHOLDS).astype( + np.float32 + ) # Assign households to states (use NC=37 and AK=2 for testing) # 40 households in NC, 10 in AK @@ -100,14 +102,18 @@ def create_test_fixture(): f["household_id"].create_dataset(TIME_PERIOD, data=household_ids) f.create_group("household_weight") - f["household_weight"].create_dataset(TIME_PERIOD, data=household_weights) + f["household_weight"].create_dataset( + TIME_PERIOD, data=household_weights + ) # Person variables f.create_group("person_id") f["person_id"].create_dataset(TIME_PERIOD, data=person_ids) f.create_group("person_household_id") - f["person_household_id"].create_dataset(TIME_PERIOD, data=person_household_ids) + f["person_household_id"].create_dataset( + TIME_PERIOD, data=person_household_ids + ) f.create_group("person_weight") f["person_weight"].create_dataset(TIME_PERIOD, data=person_weights) @@ -116,14 +122,18 @@ def create_test_fixture(): f["age"].create_dataset(TIME_PERIOD, data=ages) f.create_group("employment_income") - f["employment_income"].create_dataset(TIME_PERIOD, data=employment_income) + f["employment_income"].create_dataset( + TIME_PERIOD, data=employment_income + ) # Tax unit f.create_group("tax_unit_id") f["tax_unit_id"].create_dataset(TIME_PERIOD, data=tax_unit_ids) f.create_group("person_tax_unit_id") - f["person_tax_unit_id"].create_dataset(TIME_PERIOD, data=person_tax_unit_ids) + f["person_tax_unit_id"].create_dataset( + TIME_PERIOD, data=person_tax_unit_ids + ) f.create_group("tax_unit_weight") f["tax_unit_weight"].create_dataset(TIME_PERIOD, data=tax_unit_weights) @@ -133,7 +143,9 @@ def create_test_fixture(): f["spm_unit_id"].create_dataset(TIME_PERIOD, data=spm_unit_ids) f.create_group("person_spm_unit_id") - f["person_spm_unit_id"].create_dataset(TIME_PERIOD, data=person_spm_unit_ids) + f["person_spm_unit_id"].create_dataset( + TIME_PERIOD, data=person_spm_unit_ids + ) f.create_group("spm_unit_weight") f["spm_unit_weight"].create_dataset(TIME_PERIOD, data=spm_unit_weights) @@ -143,7 +155,9 @@ def create_test_fixture(): f["family_id"].create_dataset(TIME_PERIOD, data=family_ids) f.create_group("person_family_id") - f["person_family_id"].create_dataset(TIME_PERIOD, data=person_family_ids) + f["person_family_id"].create_dataset( + TIME_PERIOD, data=person_family_ids + ) f.create_group("family_weight") f["family_weight"].create_dataset(TIME_PERIOD, data=family_weights) @@ -158,7 +172,9 @@ def create_test_fixture(): ) f.create_group("marital_unit_weight") - f["marital_unit_weight"].create_dataset(TIME_PERIOD, data=marital_unit_weights) + f["marital_unit_weight"].create_dataset( + TIME_PERIOD, data=marital_unit_weights + ) # Geography (household level) f.create_group("state_fips") diff --git a/policyengine_us_data/tests/test_calibration/test_build_matrix_masking.py b/policyengine_us_data/tests/test_calibration/test_build_matrix_masking.py index 81cd925d..122be1fb 100644 --- a/policyengine_us_data/tests/test_calibration/test_build_matrix_masking.py +++ b/policyengine_us_data/tests/test_calibration/test_build_matrix_masking.py @@ -15,7 +15,9 @@ from policyengine_us_data.storage import STORAGE_FOLDER -DATASET_PATH = str(STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5") +DATASET_PATH = str( + STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5" +) DB_PATH = str(STORAGE_FOLDER / "calibration" / "policy_data.db") DB_URI = f"sqlite:///{DB_PATH}" @@ -42,7 +44,9 @@ def matrix_result(): sim = Microsimulation(dataset=DATASET_PATH) n_records = sim.calculate("household_id").values.shape[0] - geography = assign_random_geography(n_records, n_clones=N_CLONES, seed=SEED) + geography = assign_random_geography( + n_records, n_clones=N_CLONES, seed=SEED + ) builder = UnifiedMatrixBuilder( db_uri=DB_URI, time_period=2024, @@ -54,7 +58,9 @@ def matrix_result(): target_filter={"domain_variables": ["snap", "medicaid"]}, ) X_csc = X_sparse.tocsc() - national_rows = targets_df[targets_df["geo_level"] == "national"].index.values + national_rows = targets_df[ + targets_df["geo_level"] == "national" + ].index.values district_targets = targets_df[targets_df["geo_level"] == "district"] record_idx = None for ri in range(n_records): @@ -180,7 +186,11 @@ def test_clone_visible_only_to_own_cd(self, matrix_result): vals_0 = X_csc[:, col_0].toarray().ravel() same_state_other_cd = district_targets[ - (district_targets["geographic_id"].apply(lambda g: g.startswith(state_0))) + ( + district_targets["geographic_id"].apply( + lambda g: g.startswith(state_0) + ) + ) & (district_targets["geographic_id"] != cd_0) ] @@ -210,7 +220,9 @@ def test_clone_nonzero_for_own_cd(self, matrix_result): X_csc = X.tocsc() vals_0 = X_csc[:, col_0].toarray().ravel() - any_nonzero = any(vals_0[row.name] != 0 for _, row in own_cd_targets.iterrows()) - assert any_nonzero, ( - f"Clone 0 should have at least one non-zero entry for its own CD {cd_0}" + any_nonzero = any( + vals_0[row.name] != 0 for _, row in own_cd_targets.iterrows() ) + assert ( + any_nonzero + ), f"Clone 0 should have at least one non-zero entry for its own CD {cd_0}" diff --git a/policyengine_us_data/tests/test_calibration/test_clone_and_assign.py b/policyengine_us_data/tests/test_calibration/test_clone_and_assign.py index 9eb1b6f5..93bc5473 100644 --- a/policyengine_us_data/tests/test_calibration/test_clone_and_assign.py +++ b/policyengine_us_data/tests/test_calibration/test_clone_and_assign.py @@ -69,7 +69,9 @@ def test_loads_and_normalizes(self, tmp_path): "policyengine_us_data.calibration.clone_and_assign.STORAGE_FOLDER", tmp_path, ): - blocks, cds, states, probs = load_global_block_distribution.__wrapped__() + blocks, cds, states, probs = ( + load_global_block_distribution.__wrapped__() + ) assert len(blocks) == 9 np.testing.assert_almost_equal(probs.sum(), 1.0) @@ -138,11 +140,12 @@ def test_no_cd_collisions_across_clones(self, mock_load): r = assign_random_geography(n_records=100, n_clones=3, seed=42) for rec in range(r.n_records): rec_cds = [ - r.cd_geoid[clone * r.n_records + rec] for clone in range(r.n_clones) + r.cd_geoid[clone * r.n_records + rec] + for clone in range(r.n_clones) ] - assert len(rec_cds) == len(set(rec_cds)), ( - f"Record {rec} has duplicate CDs: {rec_cds}" - ) + assert len(rec_cds) == len( + set(rec_cds) + ), f"Record {rec} has duplicate CDs: {rec_cds}" def test_missing_file_raises(self, tmp_path): fake = tmp_path / "nonexistent" diff --git a/policyengine_us_data/tests/test_calibration/test_county_assignment.py b/policyengine_us_data/tests/test_calibration/test_county_assignment.py index d9b64991..03d7342d 100644 --- a/policyengine_us_data/tests/test_calibration/test_county_assignment.py +++ b/policyengine_us_data/tests/test_calibration/test_county_assignment.py @@ -47,7 +47,9 @@ def test_ny_cd_gets_ny_counties(self): for idx in result: county_name = County._member_names_[idx] # Should end with _NY - assert county_name.endswith("_NY"), f"Got non-NY county: {county_name}" + assert county_name.endswith( + "_NY" + ), f"Got non-NY county: {county_name}" def test_ca_cd_gets_ca_counties(self): """Verify CA CDs get CA counties.""" @@ -56,7 +58,9 @@ def test_ca_cd_gets_ca_counties(self): for idx in result: county_name = County._member_names_[idx] - assert county_name.endswith("_CA"), f"Got non-CA county: {county_name}" + assert county_name.endswith( + "_CA" + ), f"Got non-CA county: {county_name}" class TestCountyIndex: diff --git a/policyengine_us_data/tests/test_calibration/test_puf_impute.py b/policyengine_us_data/tests/test_calibration/test_puf_impute.py index d803486e..1bce3cf7 100644 --- a/policyengine_us_data/tests/test_calibration/test_puf_impute.py +++ b/policyengine_us_data/tests/test_calibration/test_puf_impute.py @@ -150,7 +150,9 @@ def test_reduces_to_target(self): rng.uniform(500_000, 5_000_000, size=250), ] ) - idx = _stratified_subsample_index(income, target_n=10_000, top_pct=99.5) + idx = _stratified_subsample_index( + income, target_n=10_000, top_pct=99.5 + ) assert len(idx) == 10_000 def test_preserves_top_earners(self): @@ -164,7 +166,9 @@ def test_preserves_top_earners(self): threshold = np.percentile(income, 99.5) n_top = (income >= threshold).sum() - idx = _stratified_subsample_index(income, target_n=10_000, top_pct=99.5) + idx = _stratified_subsample_index( + income, target_n=10_000, top_pct=99.5 + ) selected_income = income[idx] n_top_selected = (selected_income >= threshold).sum() assert n_top_selected == n_top diff --git a/policyengine_us_data/tests/test_calibration/test_retirement_imputation.py b/policyengine_us_data/tests/test_calibration/test_retirement_imputation.py index 5b635c79..d8740d16 100644 --- a/policyengine_us_data/tests/test_calibration/test_retirement_imputation.py +++ b/policyengine_us_data/tests/test_calibration/test_retirement_imputation.py @@ -54,8 +54,14 @@ def _make_mock_data(n_persons=20, n_households=5, time_period=2024): "person_household_id": {time_period: hh_ids_person}, "person_tax_unit_id": {time_period: hh_ids_person.copy()}, "person_spm_unit_id": {time_period: hh_ids_person.copy()}, - "age": {time_period: rng.integers(18, 80, size=n_persons).astype(np.float32)}, - "is_male": {time_period: rng.integers(0, 2, size=n_persons).astype(np.float32)}, + "age": { + time_period: rng.integers(18, 80, size=n_persons).astype( + np.float32 + ) + }, + "is_male": { + time_period: rng.integers(0, 2, size=n_persons).astype(np.float32) + }, "household_weight": {time_period: np.ones(n_households) * 1000}, "employment_income": { time_period: rng.uniform(0, 100_000, n_persons).astype(np.float32) @@ -65,7 +71,9 @@ def _make_mock_data(n_persons=20, n_households=5, time_period=2024): }, } for var in CPS_RETIREMENT_VARIABLES: - data[var] = {time_period: rng.uniform(0, 5000, n_persons).astype(np.float32)} + data[var] = { + time_period: rng.uniform(0, 5000, n_persons).astype(np.float32) + } return data @@ -129,9 +137,9 @@ class TestConstants: def test_retirement_vars_not_in_imputed(self): """Retirement vars must NOT be in IMPUTED_VARIABLES.""" for var in CPS_RETIREMENT_VARIABLES: - assert var not in IMPUTED_VARIABLES, ( - f"{var} should not be in IMPUTED_VARIABLES" - ) + assert ( + var not in IMPUTED_VARIABLES + ), f"{var} should not be in IMPUTED_VARIABLES" def test_retirement_vars_not_in_overridden(self): for var in CPS_RETIREMENT_VARIABLES: @@ -161,12 +169,14 @@ def test_retirement_predictors_include_demographics(self): def test_income_predictors_in_imputed_variables(self): """All income predictors must be available from PUF QRF.""" for var in RETIREMENT_INCOME_PREDICTORS: - assert var in IMPUTED_VARIABLES, ( - f"{var} not in IMPUTED_VARIABLES — won't be in puf_imputations" - ) + assert ( + var in IMPUTED_VARIABLES + ), f"{var} not in IMPUTED_VARIABLES — won't be in puf_imputations" def test_predictors_are_combined_lists(self): - expected = RETIREMENT_DEMOGRAPHIC_PREDICTORS + RETIREMENT_INCOME_PREDICTORS + expected = ( + RETIREMENT_DEMOGRAPHIC_PREDICTORS + RETIREMENT_INCOME_PREDICTORS + ) assert RETIREMENT_PREDICTORS == expected @@ -258,12 +268,18 @@ def _setup(self): self.puf_imputations = { "employment_income": emp, "self_employment_income": se, - "taxable_interest_income": rng.uniform(0, 5_000, self.n).astype(np.float32), + "taxable_interest_income": rng.uniform(0, 5_000, self.n).astype( + np.float32 + ), "qualified_dividend_income": rng.uniform(0, 3_000, self.n).astype( np.float32 ), - "taxable_pension_income": rng.uniform(0, 20_000, self.n).astype(np.float32), - "social_security": rng.uniform(0, 15_000, self.n).astype(np.float32), + "taxable_pension_income": rng.uniform(0, 20_000, self.n).astype( + np.float32 + ), + "social_security": rng.uniform(0, 15_000, self.n).astype( + np.float32 + ), } self.cps_df = _make_cps_df(self.n, rng) @@ -301,7 +317,10 @@ def _uniform_preds(self, value): def _random_preds(self, low, high, seed=99): rng = np.random.default_rng(seed) return pd.DataFrame( - {var: rng.uniform(low, high, self.n) for var in CPS_RETIREMENT_VARIABLES} + { + var: rng.uniform(low, high, self.n) + for var in CPS_RETIREMENT_VARIABLES + } ) def test_returns_all_retirement_vars(self): @@ -346,23 +365,27 @@ def test_401k_zero_when_no_wages(self): "traditional_401k_contributions", "roth_401k_contributions", ): - assert np.all(result[var][zero_wage] == 0), ( - f"{var} should be 0 when employment_income is 0" - ) + assert np.all( + result[var][zero_wage] == 0 + ), f"{var} should be 0 when employment_income is 0" def test_se_pension_zero_when_no_se_income(self): result = self._call_with_mocks(self._uniform_preds(5_000.0)) zero_se = self.puf_imputations["self_employment_income"] == 0 assert zero_se.sum() == 20 - assert np.all(result["self_employed_pension_contributions"][zero_se] == 0) + assert np.all( + result["self_employed_pension_contributions"][zero_se] == 0 + ) def test_catch_up_age_threshold(self): """Records age >= 50 get higher caps than younger.""" - self.cps_df["age"] = np.concatenate([np.full(25, 30.0), np.full(25, 55.0)]) - # All have positive income - self.puf_imputations["employment_income"] = np.full(self.n, 100_000.0).astype( - np.float32 + self.cps_df["age"] = np.concatenate( + [np.full(25, 30.0), np.full(25, 55.0)] ) + # All have positive income + self.puf_imputations["employment_income"] = np.full( + self.n, 100_000.0 + ).astype(np.float32) lim = _get_retirement_limits(self.time_period) val = float(lim["401k"]) + 1000 # 24000 @@ -379,7 +402,9 @@ def test_catch_up_age_threshold(self): def test_ira_catch_up_threshold(self): """IRA catch-up also works for age >= 50.""" - self.cps_df["age"] = np.concatenate([np.full(25, 30.0), np.full(25, 55.0)]) + self.cps_df["age"] = np.concatenate( + [np.full(25, 30.0), np.full(25, 55.0)] + ) lim = _get_retirement_limits(self.time_period) val = float(lim["ira"]) + 500 # 7500 @@ -405,7 +430,9 @@ def test_401k_nonzero_for_positive_wages(self): def test_se_pension_nonzero_for_positive_se(self): result = self._call_with_mocks(self._uniform_preds(5_000.0)) pos_se = self.puf_imputations["self_employment_income"] > 0 - assert np.all(result["self_employed_pension_contributions"][pos_se] > 0) + assert np.all( + result["self_employed_pension_contributions"][pos_se] > 0 + ) def test_se_pension_capped_at_rate_times_income(self): """SE pension should not exceed 25% of SE income.""" @@ -431,7 +458,9 @@ def test_qrf_failure_returns_zeros(self): # Make a QRF that crashes on fit_predict mock_qrf_cls = MagicMock() - mock_qrf_cls.return_value.fit_predict.side_effect = RuntimeError("QRF exploded") + mock_qrf_cls.return_value.fit_predict.side_effect = RuntimeError( + "QRF exploded" + ) qrf_mod = sys.modules["microimpute.models.qrf"] old_qrf = getattr(qrf_mod, "QRF", None) @@ -457,7 +486,9 @@ def test_training_data_failure_returns_zeros(self): import sys mock_sim = MagicMock() - mock_sim.calculate_dataframe.side_effect = ValueError("missing variable") + mock_sim.calculate_dataframe.side_effect = ValueError( + "missing variable" + ) qrf_mod = sys.modules["microimpute.models.qrf"] old_qrf = getattr(qrf_mod, "QRF", None) @@ -507,7 +538,9 @@ def test_retirement_vars_use_imputed_when_available(self): state_fips = np.array([1, 2, 36, 6, 48]) n = 20 - fake_retirement = {var: np.full(n, 999.0) for var in CPS_RETIREMENT_VARIABLES} + fake_retirement = { + var: np.full(n, 999.0) for var in CPS_RETIREMENT_VARIABLES + } with ( patch( @@ -548,8 +581,12 @@ def test_cps_half_unchanged_with_imputation(self): state_fips = np.array([1, 2, 36, 6, 48]) n = 20 - originals = {var: data[var][2024].copy() for var in CPS_RETIREMENT_VARIABLES} - fake_retirement = {var: np.zeros(n) for var in CPS_RETIREMENT_VARIABLES} + originals = { + var: data[var][2024].copy() for var in CPS_RETIREMENT_VARIABLES + } + fake_retirement = { + var: np.zeros(n) for var in CPS_RETIREMENT_VARIABLES + } with ( patch( @@ -580,7 +617,9 @@ def test_cps_half_unchanged_with_imputation(self): ) for var in CPS_RETIREMENT_VARIABLES: - np.testing.assert_array_equal(result[var][2024][:n], originals[var]) + np.testing.assert_array_equal( + result[var][2024][:n], originals[var] + ) def test_puf_half_gets_zero_retirement_for_zero_imputed(self): """When imputation returns zeros, PUF half should be zero.""" @@ -588,7 +627,9 @@ def test_puf_half_gets_zero_retirement_for_zero_imputed(self): state_fips = np.array([1, 2, 36, 6, 48]) n = 20 - fake_retirement = {var: np.zeros(n) for var in CPS_RETIREMENT_VARIABLES} + fake_retirement = { + var: np.zeros(n) for var in CPS_RETIREMENT_VARIABLES + } with ( patch( @@ -658,6 +699,6 @@ def test_401k_ira_from_policyengine_us(self): ours = _get_retirement_limits(year) pe = pe_limits(year) for key in ["401k", "401k_catch_up", "ira", "ira_catch_up"]: - assert ours[key] == pe[key], ( - f"Year {year} key {key}: {ours[key]} != {pe[key]}" - ) + assert ( + ours[key] == pe[key] + ), f"Year {year} key {key}: {ours[key]} != {pe[key]}" diff --git a/policyengine_us_data/tests/test_calibration/test_source_impute.py b/policyengine_us_data/tests/test_calibration/test_source_impute.py index 517a559e..c69ec653 100644 --- a/policyengine_us_data/tests/test_calibration/test_source_impute.py +++ b/policyengine_us_data/tests/test_calibration/test_source_impute.py @@ -71,7 +71,9 @@ def test_scf_variables_defined(self): def test_all_source_variables_defined(self): expected = ( - ACS_IMPUTED_VARIABLES + SIPP_IMPUTED_VARIABLES + SCF_IMPUTED_VARIABLES + ACS_IMPUTED_VARIABLES + + SIPP_IMPUTED_VARIABLES + + SCF_IMPUTED_VARIABLES ) assert ALL_SOURCE_VARIABLES == expected diff --git a/policyengine_us_data/tests/test_calibration/test_stacked_dataset_builder.py b/policyengine_us_data/tests/test_calibration/test_stacked_dataset_builder.py index 339dec4e..b4f4831d 100644 --- a/policyengine_us_data/tests/test_calibration/test_stacked_dataset_builder.py +++ b/policyengine_us_data/tests/test_calibration/test_stacked_dataset_builder.py @@ -40,7 +40,9 @@ def _make_geography(n_hh, cds): ], dtype="U15", ) - state_fips_arr = np.array([int(cd) // 100 for cd in cd_geoid], dtype=np.int32) + state_fips_arr = np.array( + [int(cd) // 100 for cd in cd_geoid], dtype=np.int32 + ) county_fips = np.array([b[:5] for b in block_geoid], dtype="U5") return GeographyAssignment( block_geoid=block_geoid, @@ -152,15 +154,17 @@ def test_counties_match_state(self, stacked_result): state_fips = row["state_fips"] if state_fips == 37: - assert county.endswith("_NC"), ( - f"NC county should end with _NC: {county}" - ) + assert county.endswith( + "_NC" + ), f"NC county should end with _NC: {county}" elif state_fips == 2: - assert county.endswith("_AK"), ( - f"AK county should end with _AK: {county}" - ) + assert county.endswith( + "_AK" + ), f"AK county should end with _AK: {county}" - def test_household_count_matches_weights(self, stacked_result, test_weights): + def test_household_count_matches_weights( + self, stacked_result, test_weights + ): """Number of output households should match non-zero weights.""" hh_df = stacked_result["hh_df"] expected_households = (test_weights > 0).sum() @@ -218,30 +222,40 @@ class TestEntityReindexing: def test_family_ids_are_unique(self, stacked_sim): """Family IDs should be globally unique across all CDs.""" family_ids = stacked_sim.calculate("family_id", map_to="family").values - assert len(family_ids) == len(set(family_ids)), "Family IDs should be unique" + assert len(family_ids) == len( + set(family_ids) + ), "Family IDs should be unique" def test_tax_unit_ids_are_unique(self, stacked_sim): """Tax unit IDs should be globally unique.""" - tax_unit_ids = stacked_sim.calculate("tax_unit_id", map_to="tax_unit").values - assert len(tax_unit_ids) == len(set(tax_unit_ids)), ( - "Tax unit IDs should be unique" - ) + tax_unit_ids = stacked_sim.calculate( + "tax_unit_id", map_to="tax_unit" + ).values + assert len(tax_unit_ids) == len( + set(tax_unit_ids) + ), "Tax unit IDs should be unique" def test_spm_unit_ids_are_unique(self, stacked_sim): """SPM unit IDs should be globally unique.""" - spm_unit_ids = stacked_sim.calculate("spm_unit_id", map_to="spm_unit").values - assert len(spm_unit_ids) == len(set(spm_unit_ids)), ( - "SPM unit IDs should be unique" - ) + spm_unit_ids = stacked_sim.calculate( + "spm_unit_id", map_to="spm_unit" + ).values + assert len(spm_unit_ids) == len( + set(spm_unit_ids) + ), "SPM unit IDs should be unique" def test_person_family_id_matches_family_id(self, stacked_sim): """person_family_id should reference valid family_ids.""" person_family_ids = stacked_sim.calculate( "person_family_id", map_to="person" ).values - family_ids = set(stacked_sim.calculate("family_id", map_to="family").values) + family_ids = set( + stacked_sim.calculate("family_id", map_to="family").values + ) for pf_id in person_family_ids: - assert pf_id in family_ids, f"person_family_id {pf_id} not in family_ids" + assert ( + pf_id in family_ids + ), f"person_family_id {pf_id} not in family_ids" def test_family_ids_unique_across_cds(self, stacked_sim_with_overlap): """Same HH in different CDs should get different family_ids.""" @@ -252,9 +266,9 @@ def test_family_ids_unique_across_cds(self, stacked_sim_with_overlap): family_ids = sim.calculate("family_id", map_to="family").values expected_families = n_overlap * n_cds - assert len(family_ids) == expected_families, ( - f"Expected {expected_families} families, got {len(family_ids)}" - ) + assert ( + len(family_ids) == expected_families + ), f"Expected {expected_families} families, got {len(family_ids)}" assert len(set(family_ids)) == expected_families, ( f"Family IDs not unique: " f"{len(set(family_ids))} unique " diff --git a/policyengine_us_data/tests/test_calibration/test_target_config.py b/policyengine_us_data/tests/test_calibration/test_target_config.py index 377d3a64..b19fc94f 100644 --- a/policyengine_us_data/tests/test_calibration/test_target_config.py +++ b/policyengine_us_data/tests/test_calibration/test_target_config.py @@ -104,7 +104,9 @@ def test_domain_variable_matching(self, sample_targets): def test_matrix_and_names_stay_in_sync(self, sample_targets): df, X, names = sample_targets - config = {"exclude": [{"variable": "person_count", "geo_level": "national"}]} + config = { + "exclude": [{"variable": "person_count", "geo_level": "national"}] + } out_df, out_X, out_names = apply_target_config(df, X, names, config) assert out_X.shape[0] == len(out_df) assert len(out_names) == len(out_df) @@ -112,7 +114,9 @@ def test_matrix_and_names_stay_in_sync(self, sample_targets): def test_no_match_keeps_all(self, sample_targets): df, X, names = sample_targets - config = {"exclude": [{"variable": "nonexistent", "geo_level": "national"}]} + config = { + "exclude": [{"variable": "nonexistent", "geo_level": "national"}] + } out_df, out_X, out_names = apply_target_config(df, X, names, config) assert len(out_df) == len(df) assert out_X.shape[0] == X.shape[0] diff --git a/policyengine_us_data/tests/test_calibration/test_unified_calibration.py b/policyengine_us_data/tests/test_calibration/test_unified_calibration.py index 28a3c906..d182db5a 100644 --- a/policyengine_us_data/tests/test_calibration/test_unified_calibration.py +++ b/policyengine_us_data/tests/test_calibration/test_unified_calibration.py @@ -78,8 +78,12 @@ class TestBlockSaltedDraws: def test_same_block_same_results(self): blocks = np.array(["370010001001001"] * 500) - d1 = compute_block_takeup_for_entities("takes_up_snap_if_eligible", 0.8, blocks) - d2 = compute_block_takeup_for_entities("takes_up_snap_if_eligible", 0.8, blocks) + d1 = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks + ) + d2 = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks + ) np.testing.assert_array_equal(d1, d2) def test_different_blocks_different_results(self): @@ -98,8 +102,12 @@ def test_different_blocks_different_results(self): def test_different_vars_different_results(self): blocks = np.array(["370010001001001"] * 500) - d1 = compute_block_takeup_for_entities("takes_up_snap_if_eligible", 0.8, blocks) - d2 = compute_block_takeup_for_entities("takes_up_aca_if_eligible", 0.8, blocks) + d1 = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks + ) + d2 = compute_block_takeup_for_entities( + "takes_up_aca_if_eligible", 0.8, blocks + ) assert not np.array_equal(d1, d2) def test_hh_salt_differs_from_block_only(self): @@ -307,7 +315,9 @@ class TestGeographyAssignmentCountyFips: """Verify county_fips field on GeographyAssignment.""" def test_county_fips_equals_block_prefix(self): - blocks = np.array(["370010001001001", "480010002002002", "060370003003003"]) + blocks = np.array( + ["370010001001001", "480010002002002", "060370003003003"] + ) ga = GeographyAssignment( block_geoid=blocks, cd_geoid=np.array(["3701", "4801", "0613"]), @@ -340,8 +350,12 @@ class TestBlockTakeupSeeding: def test_reproducible(self): blocks = np.array(["010010001001001"] * 50 + ["020010001001001"] * 50) - r1 = compute_block_takeup_for_entities("takes_up_snap_if_eligible", 0.8, blocks) - r2 = compute_block_takeup_for_entities("takes_up_snap_if_eligible", 0.8, blocks) + r1 = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks + ) + r2 = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks + ) np.testing.assert_array_equal(r1, r2) def test_different_blocks_different_draws(self): @@ -542,13 +556,17 @@ def test_state_specific_rate_resolved_from_block(self): n = 5000 blocks_nc = np.array(["370010001001001"] * n) - result_nc = compute_block_takeup_for_entities(var, rate_dict, blocks_nc) + result_nc = compute_block_takeup_for_entities( + var, rate_dict, blocks_nc + ) # NC rate=0.9, expect ~90% frac_nc = result_nc.mean() assert 0.85 < frac_nc < 0.95, f"NC frac={frac_nc}" blocks_tx = np.array(["480010002002002"] * n) - result_tx = compute_block_takeup_for_entities(var, rate_dict, blocks_tx) + result_tx = compute_block_takeup_for_entities( + var, rate_dict, blocks_tx + ) # TX rate=0.6, expect ~60% frac_tx = result_tx.mean() assert 0.55 < frac_tx < 0.65, f"TX frac={frac_tx}" diff --git a/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py b/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py index dbc76fb1..c8588b78 100644 --- a/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py +++ b/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py @@ -130,7 +130,9 @@ def _insert_aca_ptc_data(engine): ] for tid, sid, var, val, period in targets: conn.execute( - text("INSERT INTO targets VALUES (:tid, :sid, :var, :val, :period, 1)"), + text( + "INSERT INTO targets VALUES (:tid, :sid, :var, :val, :period, 1)" + ), { "tid": tid, "sid": sid, @@ -189,7 +191,9 @@ def test_geographic_id_populated(self): df = b._query_targets({"domain_variables": ["aca_ptc"]}) national = df[df["geo_level"] == "national"] self.assertTrue((national["geographic_id"] == "US").all()) - state_ca = df[(df["geo_level"] == "state") & (df["geographic_id"] == "6")] + state_ca = df[ + (df["geo_level"] == "state") & (df["geographic_id"] == "6") + ] self.assertGreater(len(state_ca), 0) @@ -221,9 +225,9 @@ def _get_targets_with_uprating(self, cpi_factor=1.1, pop_factor=1.02): } df["original_value"] = df["value"].copy() df["uprating_factor"] = df.apply( - lambda row: b._get_uprating_info(row["variable"], row["period"], factors)[ - 0 - ], + lambda row: b._get_uprating_info( + row["variable"], row["period"], factors + )[0], axis=1, ) df["value"] = df["original_value"] * df["uprating_factor"] @@ -248,7 +252,9 @@ def test_cd_sums_match_uprated_state(self): & (result["geo_level"] == "district") & ( result["geographic_id"].apply( - lambda g, s=sf: int(g) // 100 == s if g.isdigit() else False + lambda g, s=sf: ( + int(g) // 100 == s if g.isdigit() else False + ) ) ) ] @@ -282,7 +288,8 @@ def test_hif_is_one_when_cds_sum_to_state(self): b, df, factors = self._get_targets_with_uprating(cpi_factor=1.15) result = b._apply_hierarchical_uprating(df, ["aca_ptc"], factors) cd_aca = result[ - (result["variable"] == "aca_ptc") & (result["geo_level"] == "district") + (result["variable"] == "aca_ptc") + & (result["geo_level"] == "district") ] for _, row in cd_aca.iterrows(): self.assertAlmostEqual(row["hif"], 1.0, places=6) @@ -554,14 +561,18 @@ def test_state_fips_set_correctly(self, mock_msim_cls, mock_gcv): ) # First sim should get state 37 - fips_calls_0 = [c for c in sims[0].set_input_calls if c[0] == "state_fips"] + fips_calls_0 = [ + c for c in sims[0].set_input_calls if c[0] == "state_fips" + ] assert len(fips_calls_0) == 1 np.testing.assert_array_equal( fips_calls_0[0][2], np.full(4, 37, dtype=np.int32) ) # Second sim should get state 48 - fips_calls_1 = [c for c in sims[1].set_input_calls if c[0] == "state_fips"] + fips_calls_1 = [ + c for c in sims[1].set_input_calls if c[0] == "state_fips" + ] assert len(fips_calls_1) == 1 np.testing.assert_array_equal( fips_calls_1[0][2], np.full(4, 48, dtype=np.int32) @@ -602,9 +613,9 @@ def test_takeup_vars_forced_true(self, mock_msim_cls, mock_gcv): assert values.all(), f"{var} not forced True" set_true_vars.add(var) - assert takeup_var_names == set_true_vars, ( - f"Missing forced-true vars: {takeup_var_names - set_true_vars}" - ) + assert ( + takeup_var_names == set_true_vars + ), f"Missing forced-true vars: {takeup_var_names - set_true_vars}" # Entity-level calculation happens for affected target entity_calcs = [ @@ -727,7 +738,9 @@ def test_return_structure(self, mock_msim_cls, mock_gcv, mock_county_idx): return_value=["var_a"], ) @patch("policyengine_us.Microsimulation") - def test_sim_reuse_within_state(self, mock_msim_cls, mock_gcv, mock_county_idx): + def test_sim_reuse_within_state( + self, mock_msim_cls, mock_gcv, mock_county_idx + ): sim = _FakeSimulation() mock_msim_cls.return_value = sim @@ -758,7 +771,9 @@ def test_sim_reuse_within_state(self, mock_msim_cls, mock_gcv, mock_county_idx): return_value=[], ) @patch("policyengine_us.Microsimulation") - def test_fresh_sim_across_states(self, mock_msim_cls, mock_gcv, mock_county_idx): + def test_fresh_sim_across_states( + self, mock_msim_cls, mock_gcv, mock_county_idx + ): mock_msim_cls.side_effect = [ _FakeSimulation(), _FakeSimulation(), @@ -787,7 +802,9 @@ def test_fresh_sim_across_states(self, mock_msim_cls, mock_gcv, mock_county_idx) return_value=["var_a", "county"], ) @patch("policyengine_us.Microsimulation") - def test_delete_arrays_per_county(self, mock_msim_cls, mock_gcv, mock_county_idx): + def test_delete_arrays_per_county( + self, mock_msim_cls, mock_gcv, mock_county_idx + ): sim = _FakeSimulation() mock_msim_cls.return_value = sim @@ -862,7 +879,9 @@ def _make_geo(self, states, n_records=4): return_value=[], ) @patch("policyengine_us.Microsimulation") - def test_workers_gt1_creates_pool(self, mock_msim_cls, mock_gcv, mock_pool_cls): + def test_workers_gt1_creates_pool( + self, mock_msim_cls, mock_gcv, mock_pool_cls + ): mock_future = MagicMock() mock_future.result.return_value = ( 37, @@ -993,7 +1012,9 @@ def test_workers_gt1_creates_pool( return_value=[], ) @patch("policyengine_us.Microsimulation") - def test_workers_1_skips_pool(self, mock_msim_cls, mock_gcv, mock_county_idx): + def test_workers_1_skips_pool( + self, mock_msim_cls, mock_gcv, mock_county_idx + ): mock_msim_cls.return_value = _FakeSimulation() builder = self._make_builder() geo = self._make_geo(["37001"]) diff --git a/policyengine_us_data/tests/test_calibration/test_xw_consistency.py b/policyengine_us_data/tests/test_calibration/test_xw_consistency.py index 403fe1af..78ea4723 100644 --- a/policyengine_us_data/tests/test_calibration/test_xw_consistency.py +++ b/policyengine_us_data/tests/test_calibration/test_xw_consistency.py @@ -17,7 +17,9 @@ from policyengine_us_data.storage import STORAGE_FOLDER -DATASET_PATH = str(STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5") +DATASET_PATH = str( + STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5" +) DB_PATH = str(STORAGE_FOLDER / "calibration" / "policy_data.db") DB_URI = f"sqlite:///{DB_PATH}" @@ -99,7 +101,9 @@ def test_xw_matches_stacked_sim(): for i, cd in enumerate(cds_ordered): mask = geography.cd_geoid.astype(str) == cd cd_weights[cd] = w[mask].sum() - top_cds = sorted(cd_weights, key=cd_weights.get, reverse=True)[:N_CDS_TO_CHECK] + top_cds = sorted(cd_weights, key=cd_weights.get, reverse=True)[ + :N_CDS_TO_CHECK + ] check_vars = ["aca_ptc", "snap"] tmpdir = tempfile.mkdtemp() @@ -125,7 +129,8 @@ def test_xw_matches_stacked_sim(): stacked_sum = (vals * hh_weight).sum() cd_row = targets_df[ - (targets_df["variable"] == var) & (targets_df["geographic_id"] == cd) + (targets_df["variable"] == var) + & (targets_df["geographic_id"] == cd) ] if len(cd_row) == 0: continue diff --git a/policyengine_us_data/tests/test_constraint_validation.py b/policyengine_us_data/tests/test_constraint_validation.py index e494f5c9..29920475 100644 --- a/policyengine_us_data/tests/test_constraint_validation.py +++ b/policyengine_us_data/tests/test_constraint_validation.py @@ -138,7 +138,9 @@ def test_conflicting_lower_bounds(self): Constraint(variable="age", operation=">", value="20"), Constraint(variable="age", operation=">=", value="25"), ] - with pytest.raises(ConstraintValidationError, match="conflicting lower bounds"): + with pytest.raises( + ConstraintValidationError, match="conflicting lower bounds" + ): ensure_consistent_constraint_set(constraints) def test_conflicting_upper_bounds(self): @@ -147,7 +149,9 @@ def test_conflicting_upper_bounds(self): Constraint(variable="age", operation="<", value="50"), Constraint(variable="age", operation="<=", value="45"), ] - with pytest.raises(ConstraintValidationError, match="conflicting upper bounds"): + with pytest.raises( + ConstraintValidationError, match="conflicting upper bounds" + ): ensure_consistent_constraint_set(constraints) @@ -189,7 +193,9 @@ class TestNonNumericValues: def test_string_equality_valid(self): """medicaid_enrolled == 'True' should pass.""" constraints = [ - Constraint(variable="medicaid_enrolled", operation="==", value="True"), + Constraint( + variable="medicaid_enrolled", operation="==", value="True" + ), ] ensure_consistent_constraint_set(constraints) # No exception diff --git a/policyengine_us_data/tests/test_database_build.py b/policyengine_us_data/tests/test_database_build.py index 87a6ce08..0bdcdeb7 100644 --- a/policyengine_us_data/tests/test_database_build.py +++ b/policyengine_us_data/tests/test_database_build.py @@ -22,9 +22,7 @@ # HuggingFace URL for the stratified CPS dataset. # ETL scripts use this only to derive the time period (2024). -HF_DATASET = ( - "hf://policyengine/policyengine-us-data/calibration/stratified_extended_cps.h5" -) +HF_DATASET = "hf://policyengine/policyengine-us-data/calibration/stratified_extended_cps.h5" # Scripts run in the same order as `make database` in the Makefile. # create_database_tables.py does not use etl_argparser. @@ -79,7 +77,9 @@ def built_db(): ) if errors: - pytest.fail(f"{len(errors)} ETL script(s) failed:\n" + "\n\n".join(errors)) + pytest.fail( + f"{len(errors)} ETL script(s) failed:\n" + "\n\n".join(errors) + ) assert DB_PATH.exists(), "policy_data.db was not created" return DB_PATH @@ -96,7 +96,9 @@ def test_expected_tables_exist(built_db): conn = sqlite3.connect(str(built_db)) tables = { row[0] - for row in conn.execute("SELECT name FROM sqlite_master WHERE type='table'") + for row in conn.execute( + "SELECT name FROM sqlite_master WHERE type='table'" + ) } conn.close() @@ -120,9 +122,9 @@ def test_national_targets_loaded(built_db): variables = {r[0] for r in rows} for expected in ["snap", "social_security", "ssi"]: - assert expected in variables, ( - f"National target '{expected}' missing. Found: {sorted(variables)}" - ) + assert ( + expected in variables + ), f"National target '{expected}' missing. Found: {sorted(variables)}" def test_state_income_tax_targets(built_db): @@ -146,9 +148,9 @@ def test_state_income_tax_targets(built_db): # California should be the largest, over $100B. ca_val = state_totals.get("06") or state_totals.get("6") assert ca_val is not None, "California (FIPS 06) target missing" - assert ca_val > 100e9, ( - f"California income tax should be > $100B, got ${ca_val / 1e9:.1f}B" - ) + assert ( + ca_val > 100e9 + ), f"California income tax should be > $100B, got ${ca_val / 1e9:.1f}B" def test_congressional_district_strata(built_db): @@ -169,7 +171,9 @@ def test_all_target_variables_exist_in_policyengine(built_db): from policyengine_us.system import system conn = sqlite3.connect(str(built_db)) - variables = {r[0] for r in conn.execute("SELECT DISTINCT variable FROM targets")} + variables = { + r[0] for r in conn.execute("SELECT DISTINCT variable FROM targets") + } conn.close() missing = [v for v in variables if v not in system.variables] diff --git a/policyengine_us_data/tests/test_datasets/test_county_fips.py b/policyengine_us_data/tests/test_datasets/test_county_fips.py index ac2eb9fa..0414aa55 100644 --- a/policyengine_us_data/tests/test_datasets/test_county_fips.py +++ b/policyengine_us_data/tests/test_datasets/test_county_fips.py @@ -48,7 +48,9 @@ def mock_upload_to_hf(): def mock_local_folder(): """Mock the LOCAL_FOLDER""" mock_path = MagicMock() - with patch("policyengine_us_data.geography.county_fips.LOCAL_FOLDER", mock_path): + with patch( + "policyengine_us_data.geography.county_fips.LOCAL_FOLDER", mock_path + ): yield mock_path @@ -177,4 +179,6 @@ def test_huggingface_upload(mock_upload_to_hf, mock_to_csv, mock_requests_get): assert call_kwargs["repo_file_path"] == "county_fips_2020.csv.gz" # Verify that the first parameter is a BytesIO instance - assert isinstance(mock_upload_to_hf.call_args[1]["local_file_path"], BytesIO) + assert isinstance( + mock_upload_to_hf.call_args[1]["local_file_path"], BytesIO + ) diff --git a/policyengine_us_data/tests/test_datasets/test_cps.py b/policyengine_us_data/tests/test_datasets/test_cps.py index f0346939..bbfba73b 100644 --- a/policyengine_us_data/tests/test_datasets/test_cps.py +++ b/policyengine_us_data/tests/test_datasets/test_cps.py @@ -13,11 +13,18 @@ def test_cps_has_auto_loan_interest(): RELATIVE_TOLERANCE = 0.4 assert ( - abs(sim.calculate("auto_loan_interest").sum() / AUTO_LOAN_INTEREST_TARGET - 1) + abs( + sim.calculate("auto_loan_interest").sum() + / AUTO_LOAN_INTEREST_TARGET + - 1 + ) < RELATIVE_TOLERANCE ) assert ( - abs(sim.calculate("auto_loan_balance").sum() / AUTO_LOAN_BALANCE_TARGET - 1) + abs( + sim.calculate("auto_loan_balance").sum() / AUTO_LOAN_BALANCE_TARGET + - 1 + ) < RELATIVE_TOLERANCE ) @@ -31,7 +38,11 @@ def test_cps_has_fsla_overtime_premium(): OVERTIME_PREMIUM_TARGET = 70e9 RELATIVE_TOLERANCE = 0.2 assert ( - abs(sim.calculate("fsla_overtime_premium").sum() / OVERTIME_PREMIUM_TARGET - 1) + abs( + sim.calculate("fsla_overtime_premium").sum() + / OVERTIME_PREMIUM_TARGET + - 1 + ) < RELATIVE_TOLERANCE ) diff --git a/policyengine_us_data/tests/test_datasets/test_dataset_sanity.py b/policyengine_us_data/tests/test_datasets/test_dataset_sanity.py index 4e8732b0..8314fe7f 100644 --- a/policyengine_us_data/tests/test_datasets/test_dataset_sanity.py +++ b/policyengine_us_data/tests/test_datasets/test_dataset_sanity.py @@ -41,23 +41,27 @@ def test_ecps_employment_income_positive(ecps_sim): def test_ecps_self_employment_income_positive(ecps_sim): total = ecps_sim.calculate("self_employment_income").sum() - assert total > 50e9, f"self_employment_income sum is {total:.2e}, expected > 50B." + assert ( + total > 50e9 + ), f"self_employment_income sum is {total:.2e}, expected > 50B." def test_ecps_household_count(ecps_sim): """Household count should be roughly 130-160M.""" total_hh = ecps_sim.calculate("household_weight").values.sum() - assert 100e6 < total_hh < 200e6, ( - f"Total households = {total_hh:.2e}, expected 100M-200M." - ) + assert ( + 100e6 < total_hh < 200e6 + ), f"Total households = {total_hh:.2e}, expected 100M-200M." def test_ecps_person_count(ecps_sim): """Weighted person count should be roughly 330M.""" - total_people = ecps_sim.calculate("household_weight", map_to="person").values.sum() - assert 250e6 < total_people < 400e6, ( - f"Total people = {total_people:.2e}, expected 250M-400M." - ) + total_people = ecps_sim.calculate( + "household_weight", map_to="person" + ).values.sum() + assert ( + 250e6 < total_people < 400e6 + ), f"Total people = {total_people:.2e}, expected 250M-400M." def test_ecps_poverty_rate_reasonable(ecps_sim): @@ -80,9 +84,9 @@ def test_ecps_mean_employment_income_reasonable(ecps_sim): """Mean employment income per person should be $20k-$60k.""" income = ecps_sim.calculate("employment_income", map_to="person") mean = income.mean() - assert 15_000 < mean < 80_000, ( - f"Mean employment income = ${mean:,.0f}, expected $15k-$80k." - ) + assert ( + 15_000 < mean < 80_000 + ), f"Mean employment income = ${mean:,.0f}, expected $15k-$80k." # ── CPS sanity checks ─────────────────────────────────────────── @@ -90,7 +94,9 @@ def test_ecps_mean_employment_income_reasonable(ecps_sim): def test_cps_employment_income_positive(cps_sim): total = cps_sim.calculate("employment_income").sum() - assert total > 5e12, f"CPS employment_income sum is {total:.2e}, expected > 5T." + assert ( + total > 5e12 + ), f"CPS employment_income sum is {total:.2e}, expected > 5T." def test_cps_household_count(cps_sim): @@ -116,20 +122,24 @@ def sparse_sim(): def test_sparse_employment_income_positive(sparse_sim): """Sparse dataset employment income must be in the trillions.""" total = sparse_sim.calculate("employment_income").sum() - assert total > 5e12, f"Sparse employment_income sum is {total:.2e}, expected > 5T." + assert ( + total > 5e12 + ), f"Sparse employment_income sum is {total:.2e}, expected > 5T." def test_sparse_household_count(sparse_sim): total_hh = sparse_sim.calculate("household_weight").values.sum() - assert 100e6 < total_hh < 200e6, ( - f"Sparse total households = {total_hh:.2e}, expected 100M-200M." - ) + assert ( + 100e6 < total_hh < 200e6 + ), f"Sparse total households = {total_hh:.2e}, expected 100M-200M." def test_sparse_poverty_rate_reasonable(sparse_sim): in_poverty = sparse_sim.calculate("person_in_poverty", map_to="person") rate = in_poverty.mean() - assert 0.05 < rate < 0.30, f"Sparse poverty rate = {rate:.1%}, expected 5-30%." + assert ( + 0.05 < rate < 0.30 + ), f"Sparse poverty rate = {rate:.1%}, expected 5-30%." # ── File size checks ─────────────────────────────────────────── @@ -143,6 +153,6 @@ def test_ecps_file_size(): if not path.exists(): pytest.skip("enhanced_cps_2024.h5 not found") size_mb = path.stat().st_size / (1024 * 1024) - assert size_mb > 100, ( - f"enhanced_cps_2024.h5 is only {size_mb:.1f}MB, expected >100MB" - ) + assert ( + size_mb > 100 + ), f"enhanced_cps_2024.h5 is only {size_mb:.1f}MB, expected >100MB" diff --git a/policyengine_us_data/tests/test_datasets/test_enhanced_cps.py b/policyengine_us_data/tests/test_datasets/test_enhanced_cps.py index 298de5a4..4c79874e 100644 --- a/policyengine_us_data/tests/test_datasets/test_enhanced_cps.py +++ b/policyengine_us_data/tests/test_datasets/test_enhanced_cps.py @@ -50,10 +50,10 @@ def test_ecps_replicates_jct_tax_expenditures(): & (calibration_log["epoch"] == calibration_log["epoch"].max()) ] - assert jct_rows.rel_abs_error.max() < 0.5, ( - "JCT tax expenditure targets not met (see the calibration log for details). Max relative error: {:.2%}".format( - jct_rows.rel_abs_error.max() - ) + assert ( + jct_rows.rel_abs_error.max() < 0.5 + ), "JCT tax expenditure targets not met (see the calibration log for details). Max relative error: {:.2%}".format( + jct_rows.rel_abs_error.max() ) @@ -71,7 +71,9 @@ def deprecated_test_ecps_replicates_jct_tax_expenditures_full(): } baseline = Microsimulation(dataset=EnhancedCPS_2024) - income_tax_b = baseline.calculate("income_tax", period=2024, map_to="household") + income_tax_b = baseline.calculate( + "income_tax", period=2024, map_to="household" + ) for deduction, target in EXPENDITURE_TARGETS.items(): # Create reform that neutralizes the deduction @@ -80,8 +82,12 @@ def apply(self): self.neutralize_variable(deduction) # Run reform simulation - reformed = Microsimulation(reform=RepealDeduction, dataset=EnhancedCPS_2024) - income_tax_r = reformed.calculate("income_tax", period=2024, map_to="household") + reformed = Microsimulation( + reform=RepealDeduction, dataset=EnhancedCPS_2024 + ) + income_tax_r = reformed.calculate( + "income_tax", period=2024, map_to="household" + ) # Calculate tax expenditure tax_expenditure = (income_tax_r - income_tax_b).sum() @@ -131,9 +137,9 @@ def test_undocumented_matches_ssn_none(): # 1. Per-person equivalence mismatches = np.where(ssn_type_none_mask != undocumented_mask)[0] - assert mismatches.size == 0, ( - f"{mismatches.size} mismatches between 'NONE' SSN and 'UNDOCUMENTED' status" - ) + assert ( + mismatches.size == 0 + ), f"{mismatches.size} mismatches between 'NONE' SSN and 'UNDOCUMENTED' status" # 2. Optional aggregate sanity-check count = undocumented_mask.sum() @@ -158,7 +164,9 @@ def test_aca_calibration(): # Monthly to yearly targets["spending"] = targets["spending"] * 12 # Adjust to match national target - targets["spending"] = targets["spending"] * (98e9 / targets["spending"].sum()) + targets["spending"] = targets["spending"] * ( + 98e9 / targets["spending"].sum() + ) sim = Microsimulation(dataset=EnhancedCPS_2024) state_code_hh = sim.calculate("state_code", map_to="household").values @@ -181,7 +189,9 @@ def test_aca_calibration(): if pct_error > TOLERANCE: failed = True - assert not failed, f"One or more states exceeded tolerance of {TOLERANCE:.0%}." + assert ( + not failed + ), f"One or more states exceeded tolerance of {TOLERANCE:.0%}." def test_immigration_status_diversity(): @@ -217,17 +227,19 @@ def test_immigration_status_diversity(): ) # Also check that we have a reasonable percentage of citizens (should be 85-90%) - assert 80 < citizen_pct < 95, ( - f"Citizen percentage ({citizen_pct:.1f}%) outside expected range (80-95%)" - ) + assert ( + 80 < citizen_pct < 95 + ), f"Citizen percentage ({citizen_pct:.1f}%) outside expected range (80-95%)" # Check that we have some non-citizens non_citizen_pct = 100 - citizen_pct - assert non_citizen_pct > 5, ( - f"Too few non-citizens ({non_citizen_pct:.1f}%) - expected at least 5%" - ) + assert ( + non_citizen_pct > 5 + ), f"Too few non-citizens ({non_citizen_pct:.1f}%) - expected at least 5%" - print(f"Immigration status diversity test passed: {citizen_pct:.1f}% citizens") + print( + f"Immigration status diversity test passed: {citizen_pct:.1f}% citizens" + ) def test_medicaid_calibration(): @@ -265,4 +277,6 @@ def test_medicaid_calibration(): if pct_error > TOLERANCE: failed = True - assert not failed, f"One or more states exceeded tolerance of {TOLERANCE:.0%}." + assert ( + not failed + ), f"One or more states exceeded tolerance of {TOLERANCE:.0%}." diff --git a/policyengine_us_data/tests/test_datasets/test_sipp_assets.py b/policyengine_us_data/tests/test_datasets/test_sipp_assets.py index a79b4bce..0f839a9c 100644 --- a/policyengine_us_data/tests/test_datasets/test_sipp_assets.py +++ b/policyengine_us_data/tests/test_datasets/test_sipp_assets.py @@ -101,12 +101,12 @@ def test_liquid_assets_distribution(): MEDIAN_MIN = 3_000 MEDIAN_MAX = 20_000 - assert weighted_median > MEDIAN_MIN, ( - f"Median liquid assets ${weighted_median:,.0f} below minimum ${MEDIAN_MIN:,}" - ) - assert weighted_median < MEDIAN_MAX, ( - f"Median liquid assets ${weighted_median:,.0f} above maximum ${MEDIAN_MAX:,}" - ) + assert ( + weighted_median > MEDIAN_MIN + ), f"Median liquid assets ${weighted_median:,.0f} below minimum ${MEDIAN_MIN:,}" + assert ( + weighted_median < MEDIAN_MAX + ), f"Median liquid assets ${weighted_median:,.0f} above maximum ${MEDIAN_MAX:,}" def test_asset_categories_exist(): @@ -127,7 +127,9 @@ def test_asset_categories_exist(): assert bonds >= 0, "Bond assets should be non-negative" # Bank accounts typically largest category of liquid assets - assert bank > stocks * 0.3, "Bank accounts should be substantial relative to stocks" + assert ( + bank > stocks * 0.3 + ), "Bank accounts should be substantial relative to stocks" def test_low_asset_households(): @@ -153,9 +155,9 @@ def test_low_asset_households(): MIN_PCT = 0.10 MAX_PCT = 0.70 - assert below_2k > MIN_PCT, ( - f"Only {below_2k:.1%} have <$2k liquid assets, expected at least {MIN_PCT:.0%}" - ) - assert below_2k < MAX_PCT, ( - f"{below_2k:.1%} have <$2k liquid assets, expected at most {MAX_PCT:.0%}" - ) + assert ( + below_2k > MIN_PCT + ), f"Only {below_2k:.1%} have <$2k liquid assets, expected at least {MIN_PCT:.0%}" + assert ( + below_2k < MAX_PCT + ), f"{below_2k:.1%} have <$2k liquid assets, expected at most {MAX_PCT:.0%}" diff --git a/policyengine_us_data/tests/test_datasets/test_small_enhanced_cps.py b/policyengine_us_data/tests/test_datasets/test_small_enhanced_cps.py index 9316d390..23b7b2dc 100644 --- a/policyengine_us_data/tests/test_datasets/test_small_enhanced_cps.py +++ b/policyengine_us_data/tests/test_datasets/test_small_enhanced_cps.py @@ -19,10 +19,12 @@ def test_small_ecps_loads(year: int): # Employment income should be positive (not zero from missing vars) emp_income = sim.calculate("employment_income", 2025).sum() - assert emp_income > 0, ( - f"Small ECPS employment_income sum is {emp_income}, expected > 0." - ) + assert ( + emp_income > 0 + ), f"Small ECPS employment_income sum is {emp_income}, expected > 0." # Should have a reasonable number of households hh_count = len(sim.calculate("household_net_income", 2025)) - assert hh_count > 100, f"Small ECPS has only {hh_count} households, expected > 100." + assert ( + hh_count > 100 + ), f"Small ECPS has only {hh_count} households, expected > 100." diff --git a/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py b/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py index a7ee941b..bea1e3b3 100644 --- a/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py +++ b/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py @@ -115,10 +115,10 @@ def test_sparse_ecps_replicates_jct_tax_expenditures(): & (calibration_log["epoch"] == calibration_log["epoch"].max()) ] - assert jct_rows.rel_abs_error.max() < 0.5, ( - "JCT tax expenditure targets not met (see the calibration log for details). Max relative error: {:.2%}".format( - jct_rows.rel_abs_error.max() - ) + assert ( + jct_rows.rel_abs_error.max() < 0.5 + ), "JCT tax expenditure targets not met (see the calibration log for details). Max relative error: {:.2%}".format( + jct_rows.rel_abs_error.max() ) @@ -133,7 +133,9 @@ def deprecated_test_sparse_ecps_replicates_jct_tax_expenditures_full(sim): } baseline = sim - income_tax_b = baseline.calculate("income_tax", period=2024, map_to="household") + income_tax_b = baseline.calculate( + "income_tax", period=2024, map_to="household" + ) for deduction, target in EXPENDITURE_TARGETS.items(): # Create reform that neutralizes the deduction @@ -143,7 +145,9 @@ def apply(self): # Run reform simulation reformed = Microsimulation(reform=RepealDeduction, dataset=sim.dataset) - income_tax_r = reformed.calculate("income_tax", period=2024, map_to="household") + income_tax_r = reformed.calculate( + "income_tax", period=2024, map_to="household" + ) # Calculate tax expenditure tax_expenditure = (income_tax_r - income_tax_b).sum() @@ -184,7 +188,9 @@ def test_sparse_aca_calibration(sim): # Monthly to yearly targets["spending"] = targets["spending"] * 12 # Adjust to match national target - targets["spending"] = targets["spending"] * (98e9 / targets["spending"].sum()) + targets["spending"] = targets["spending"] * ( + 98e9 / targets["spending"].sum() + ) state_code_hh = sim.calculate("state_code", map_to="household").values aca_ptc = sim.calculate("aca_ptc", map_to="household", period=2025) @@ -206,7 +212,9 @@ def test_sparse_aca_calibration(sim): if pct_error > TOLERANCE: failed = True - assert not failed, f"One or more states exceeded tolerance of {TOLERANCE:.0%}." + assert ( + not failed + ), f"One or more states exceeded tolerance of {TOLERANCE:.0%}." def test_sparse_medicaid_calibration(sim): @@ -238,4 +246,6 @@ def test_sparse_medicaid_calibration(sim): if pct_error > TOLERANCE: failed = True - assert not failed, f"One or more states exceeded tolerance of {TOLERANCE:.0%}." + assert ( + not failed + ), f"One or more states exceeded tolerance of {TOLERANCE:.0%}." diff --git a/policyengine_us_data/tests/test_format_comparison.py b/policyengine_us_data/tests/test_format_comparison.py index 722064b5..d5091b8b 100644 --- a/policyengine_us_data/tests/test_format_comparison.py +++ b/policyengine_us_data/tests/test_format_comparison.py @@ -87,7 +87,10 @@ def _read_h5py_arrays(h5py_path: str): arr = f[var][period_key][:] if arr.dtype.kind in ("S", "O"): arr = np.array( - [x.decode() if isinstance(x, bytes) else str(x) for x in arr] + [ + x.decode() if isinstance(x, bytes) else str(x) + for x in arr + ] ) # Wrap in nested dict keyed by the period string data[var] = {period_key: arr} @@ -114,7 +117,9 @@ def h5py_to_hdfstore(h5py_path: str, hdfstore_path: str) -> dict: print("Reading h5py file...") data, time_period, h5_vars = _read_h5py_arrays(h5py_path) n_persons = len(next(iter(data.get("person_id", {}).values()), [])) - print(f" {len(h5_vars)} variables, {n_persons:,} persons, year={time_period}") + print( + f" {len(h5_vars)} variables, {n_persons:,} persons, year={time_period}" + ) print("Splitting into entity DataFrames...") entity_dfs = split_data_into_entity_dfs(data, system, time_period) @@ -182,7 +187,11 @@ def compare_formats(h5py_path: str, hdfstore_path: str) -> dict: hdf_unique = np.unique(hdf_values) if h5_values.dtype.kind in ("U", "S", "O"): match = set( - (x.decode() if isinstance(x, bytes) else str(x)) + ( + x.decode() + if isinstance(x, bytes) + else str(x) + ) for x in h5_unique ) == set(str(x) for x in hdf_unique) else: @@ -208,11 +217,17 @@ def compare_formats(h5py_path: str, hdfstore_path: str) -> dict: if h5_values.dtype.kind in ("U", "S", "O"): h5_str = np.array( [ - (x.decode() if isinstance(x, bytes) else str(x)) + ( + x.decode() + if isinstance(x, bytes) + else str(x) + ) for x in h5_values ] ) - hdf_str = np.array([str(x) for x in hdf_values]) + hdf_str = np.array( + [str(x) for x in hdf_values] + ) if np.array_equal(h5_str, hdf_str): passed.append(var) else: @@ -313,12 +328,12 @@ def test_roundtrip(h5py_path, tmp_path): result = compare_formats(h5py_path, hdfstore_path) print_results(result) - assert len(result["failed"]) == 0, ( - f"{len(result['failed'])} variables have mismatched values" - ) - assert len(result["skipped"]) == 0, ( - f"{len(result['skipped'])} variables missing from HDFStore" - ) + assert ( + len(result["failed"]) == 0 + ), f"{len(result['failed'])} variables have mismatched values" + assert ( + len(result["skipped"]) == 0 + ), f"{len(result['skipped'])} variables missing from HDFStore" def test_manifest(h5py_path, tmp_path): @@ -327,7 +342,9 @@ def test_manifest(h5py_path, tmp_path): h5py_to_hdfstore(h5py_path, hdfstore_path) with pd.HDFStore(hdfstore_path, "r") as store: - assert "/_variable_metadata" in store.keys(), "Missing _variable_metadata table" + assert ( + "/_variable_metadata" in store.keys() + ), "Missing _variable_metadata table" manifest = store["/_variable_metadata"] assert "variable" in manifest.columns assert "entity" in manifest.columns @@ -346,15 +363,17 @@ def test_all_entities(h5py_path, tmp_path): expected = set(ENTITIES) with pd.HDFStore(hdfstore_path, "r") as store: - actual = {k.lstrip("/") for k in store.keys() if not k.startswith("/_")} + actual = { + k.lstrip("/") for k in store.keys() if not k.startswith("/_") + } missing = expected - actual assert not missing, f"Missing entity tables: {missing}" for entity in expected: df = store[f"/{entity}"] assert len(df) > 0, f"Entity {entity} has 0 rows" - assert f"{entity}_id" in df.columns, ( - f"Entity {entity} missing {entity}_id column" - ) + assert ( + f"{entity}_id" in df.columns + ), f"Entity {entity} missing {entity}_id column" print(f" {entity}: {len(df):,} rows, {len(df.columns)} cols") @@ -365,7 +384,9 @@ def test_all_entities(h5py_path, tmp_path): parser = argparse.ArgumentParser( description="Convert h5py dataset to HDFStore and verify roundtrip" ) - parser.add_argument("--h5py-path", required=True, help="Path to h5py format file") + parser.add_argument( + "--h5py-path", required=True, help="Path to h5py format file" + ) parser.add_argument( "--output-path", default=None, diff --git a/policyengine_us_data/tests/test_puf_impute.py b/policyengine_us_data/tests/test_puf_impute.py index d968fb16..fcdcf763 100644 --- a/policyengine_us_data/tests/test_puf_impute.py +++ b/policyengine_us_data/tests/test_puf_impute.py @@ -57,7 +57,9 @@ def _make_data( if age is not None: data["age"] = {tp: np.concatenate([age, age]).astype(np.float32)} if is_male is not None: - data["is_male"] = {tp: np.concatenate([is_male, is_male]).astype(np.float32)} + data["is_male"] = { + tp: np.concatenate([is_male, is_male]).astype(np.float32) + } return data, n, tp diff --git a/policyengine_us_data/tests/test_schema_views_and_lookups.py b/policyengine_us_data/tests/test_schema_views_and_lookups.py index c8e5f4f8..d7495ff3 100644 --- a/policyengine_us_data/tests/test_schema_views_and_lookups.py +++ b/policyengine_us_data/tests/test_schema_views_and_lookups.py @@ -227,7 +227,9 @@ def _query_stratum_domain(self): from sqlalchemy import text with self.engine.connect() as conn: - rows = conn.execute(text("SELECT * FROM stratum_domain")).fetchall() + rows = conn.execute( + text("SELECT * FROM stratum_domain") + ).fetchall() return rows def test_geographic_stratum_excluded(self): @@ -289,14 +291,18 @@ def _query_target_overview(self): from sqlalchemy import text with self.engine.connect() as conn: - rows = conn.execute(text("SELECT * FROM target_overview")).fetchall() + rows = conn.execute( + text("SELECT * FROM target_overview") + ).fetchall() return rows def _overview_columns(self): from sqlalchemy import text with self.engine.connect() as conn: - cursor = conn.execute(text("SELECT * FROM target_overview LIMIT 0")) + cursor = conn.execute( + text("SELECT * FROM target_overview LIMIT 0") + ) return [desc[0] for desc in cursor.cursor.description] def test_national_geo_level(self): diff --git a/policyengine_us_data/utils/census.py b/policyengine_us_data/utils/census.py index 422d750c..c61cc166 100644 --- a/policyengine_us_data/utils/census.py +++ b/policyengine_us_data/utils/census.py @@ -139,7 +139,9 @@ def get_census_docs(year): - docs_url = f"https://api.census.gov/data/{year}/acs/acs1/subject/variables.json" + docs_url = ( + f"https://api.census.gov/data/{year}/acs/acs1/subject/variables.json" + ) cache_file = f"census_docs_{year}.json" if is_cached(cache_file): logger.info(f"Using cached {cache_file}") diff --git a/policyengine_us_data/utils/constraint_validation.py b/policyengine_us_data/utils/constraint_validation.py index f533739c..d3c4305d 100644 --- a/policyengine_us_data/utils/constraint_validation.py +++ b/policyengine_us_data/utils/constraint_validation.py @@ -111,7 +111,9 @@ def _check_operation_compatibility(var_name: str, operations: set) -> None: ) -def _check_range_validity(var_name: str, constraints: List[Constraint]) -> None: +def _check_range_validity( + var_name: str, constraints: List[Constraint] +) -> None: """Check that range constraints don't create an empty range.""" lower_bound = float("-inf") upper_bound = float("inf") @@ -126,7 +128,9 @@ def _check_range_validity(var_name: str, constraints: List[Constraint]) -> None: continue if c.operation == ">": - if val > lower_bound or (val == lower_bound and not lower_inclusive): + if val > lower_bound or ( + val == lower_bound and not lower_inclusive + ): lower_bound = val lower_inclusive = False elif c.operation == ">=": @@ -134,7 +138,9 @@ def _check_range_validity(var_name: str, constraints: List[Constraint]) -> None: lower_bound = val lower_inclusive = True elif c.operation == "<": - if val < upper_bound or (val == upper_bound and not upper_inclusive): + if val < upper_bound or ( + val == upper_bound and not upper_inclusive + ): upper_bound = val upper_inclusive = False elif c.operation == "<=": @@ -148,7 +154,9 @@ def _check_range_validity(var_name: str, constraints: List[Constraint]) -> None: f"{var_name}: empty range - lower bound {lower_bound} > " f"upper bound {upper_bound}" ) - if lower_bound == upper_bound and not (lower_inclusive and upper_inclusive): + if lower_bound == upper_bound and not ( + lower_inclusive and upper_inclusive + ): raise ConstraintValidationError( f"{var_name}: empty range - bounds equal at {lower_bound} " "but not both inclusive" diff --git a/policyengine_us_data/utils/data_upload.py b/policyengine_us_data/utils/data_upload.py index c8a50036..e9509837 100644 --- a/policyengine_us_data/utils/data_upload.py +++ b/policyengine_us_data/utils/data_upload.py @@ -117,14 +117,18 @@ def upload_files_to_gcs( Upload files to Google Cloud Storage and set metadata with the version. """ credentials, project_id = google.auth.default() - storage_client = storage.Client(credentials=credentials, project=project_id) + storage_client = storage.Client( + credentials=credentials, project=project_id + ) bucket = storage_client.bucket(gcs_bucket_name) for file_path in files: file_path = Path(file_path) blob = bucket.blob(file_path.name) blob.upload_from_filename(file_path) - logging.info(f"Uploaded {file_path.name} to GCS bucket {gcs_bucket_name}.") + logging.info( + f"Uploaded {file_path.name} to GCS bucket {gcs_bucket_name}." + ) # Set metadata blob.metadata = {"version": version} @@ -161,7 +165,9 @@ def upload_local_area_file( # Upload to GCS with subdirectory credentials, project_id = google.auth.default() - storage_client = storage.Client(credentials=credentials, project=project_id) + storage_client = storage.Client( + credentials=credentials, project=project_id + ) bucket = storage_client.bucket(gcs_bucket_name) blob_name = f"{subdirectory}/{file_path.name}" @@ -331,7 +337,9 @@ def upload_to_staging_hf( f"Uploaded batch {i // batch_size + 1}: {len(operations)} files to staging/" ) - logging.info(f"Total: uploaded {total_uploaded} files to staging/ in HuggingFace") + logging.info( + f"Total: uploaded {total_uploaded} files to staging/ in HuggingFace" + ) return total_uploaded @@ -482,7 +490,9 @@ def upload_from_hf_staging_to_gcs( token = os.environ.get("HUGGING_FACE_TOKEN") credentials, project_id = google.auth.default() - storage_client = storage.Client(credentials=credentials, project=project_id) + storage_client = storage.Client( + credentials=credentials, project=project_id + ) bucket = storage_client.bucket(gcs_bucket_name) uploaded = 0 diff --git a/policyengine_us_data/utils/db.py b/policyengine_us_data/utils/db.py index 128dbb78..ad0c0669 100644 --- a/policyengine_us_data/utils/db.py +++ b/policyengine_us_data/utils/db.py @@ -11,7 +11,9 @@ ) from policyengine_us_data.storage import STORAGE_FOLDER -DEFAULT_DATASET = str(STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5") +DEFAULT_DATASET = str( + STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5" +) def etl_argparser( @@ -44,7 +46,10 @@ def etl_argparser( args = parser.parse_args() - if not args.dataset.startswith("hf://") and not Path(args.dataset).exists(): + if ( + not args.dataset.startswith("hf://") + and not Path(args.dataset).exists() + ): raise FileNotFoundError( f"Dataset not found: {args.dataset}\n" f"Either build it locally (`make data`) or pass a " @@ -66,14 +71,18 @@ def get_stratum_by_id(session: Session, stratum_id: int) -> Optional[Stratum]: return session.get(Stratum, stratum_id) -def get_simple_stratum_by_ucgid(session: Session, ucgid: str) -> Optional[Stratum]: +def get_simple_stratum_by_ucgid( + session: Session, ucgid: str +) -> Optional[Stratum]: """ Finds a stratum defined *only* by a single ucgid_str constraint. """ constraint_count_subquery = ( select( StratumConstraint.stratum_id, - sa.func.count(StratumConstraint.stratum_id).label("constraint_count"), + sa.func.count(StratumConstraint.stratum_id).label( + "constraint_count" + ), ) .group_by(StratumConstraint.stratum_id) .subquery() @@ -130,12 +139,16 @@ def parse_ucgid(ucgid_str: str) -> Dict: elif ucgid_str.startswith("0400000US"): state_fips = int(ucgid_str[9:]) return {"type": "state", "state_fips": state_fips} - elif ucgid_str.startswith("5001800US") or ucgid_str.startswith("5001900US"): + elif ucgid_str.startswith("5001800US") or ucgid_str.startswith( + "5001900US" + ): # 5001800US = 118th Congress, 5001900US = 119th Congress state_and_district = ucgid_str[9:] state_fips = int(state_and_district[:2]) district_number = int(state_and_district[2:]) - if district_number == 0 or (state_fips == 11 and district_number == 98): + if district_number == 0 or ( + state_fips == 11 and district_number == 98 + ): district_number = 1 cd_geoid = state_fips * 100 + district_number return { @@ -190,7 +203,9 @@ def get_geographic_strata(session: Session) -> Dict: if not constraints: strata_map["national"] = stratum.stratum_id else: - constraint_vars = {c.constraint_variable: c.value for c in constraints} + constraint_vars = { + c.constraint_variable: c.value for c in constraints + } if "congressional_district_geoid" in constraint_vars: cd_geoid = int(constraint_vars["congressional_district_geoid"]) diff --git a/policyengine_us_data/utils/huggingface.py b/policyengine_us_data/utils/huggingface.py index 9b1e48cb..7a090d25 100644 --- a/policyengine_us_data/utils/huggingface.py +++ b/policyengine_us_data/utils/huggingface.py @@ -10,7 +10,9 @@ ) -def download(repo: str, repo_filename: str, local_folder: str, version: str = None): +def download( + repo: str, repo_filename: str, local_folder: str, version: str = None +): hf_hub_download( repo_id=repo, @@ -216,11 +218,15 @@ def upload_calibration_artifacts( if log_dir: # Upload run config to calibration/ root for artifact validation - run_config_local = os.path.join(log_dir, f"{prefix}unified_run_config.json") + run_config_local = os.path.join( + log_dir, f"{prefix}unified_run_config.json" + ) if os.path.exists(run_config_local): operations.append( CommitOperationAdd( - path_in_repo=(f"calibration/{prefix}unified_run_config.json"), + path_in_repo=( + f"calibration/{prefix}unified_run_config.json" + ), path_or_fileobj=run_config_local, ) ) diff --git a/policyengine_us_data/utils/loss.py b/policyengine_us_data/utils/loss.py index bfbf49db..51be118b 100644 --- a/policyengine_us_data/utils/loss.py +++ b/policyengine_us_data/utils/loss.py @@ -166,7 +166,9 @@ def build_loss_matrix(dataset: type, time_period): continue mask = ( - (agi >= row["AGI lower bound"]) * (agi < row["AGI upper bound"]) * filer + (agi >= row["AGI lower bound"]) + * (agi < row["AGI upper bound"]) + * filer ) > 0 if row["Filing status"] == "Single": @@ -186,8 +188,12 @@ def build_loss_matrix(dataset: type, time_period): if row["Count"]: values = (values > 0).astype(float) - agi_range_label = f"{fmt(row['AGI lower bound'])}-{fmt(row['AGI upper bound'])}" - taxable_label = "taxable" if row["Taxable only"] else "all" + " returns" + agi_range_label = ( + f"{fmt(row['AGI lower bound'])}-{fmt(row['AGI upper bound'])}" + ) + taxable_label = ( + "taxable" if row["Taxable only"] else "all" + " returns" + ) filing_status_label = row["Filing status"] variable_label = row["Variable"].replace("_", " ") @@ -266,7 +272,9 @@ def build_loss_matrix(dataset: type, time_period): for variable_name in CBO_PROGRAMS: label = f"nation/cbo/{variable_name}" - loss_matrix[label] = sim.calculate(variable_name, map_to="household").values + loss_matrix[label] = sim.calculate( + variable_name, map_to="household" + ).values if any(loss_matrix[label].isna()): raise ValueError(f"Missing values for {label}") param_name = CBO_PARAM_NAME_MAP.get(variable_name, variable_name) @@ -306,9 +314,9 @@ def build_loss_matrix(dataset: type, time_period): # National ACA Enrollment (people receiving a PTC) label = "nation/gov/aca_enrollment" - on_ptc = (sim.calculate("aca_ptc", map_to="person", period=2025).values > 0).astype( - int - ) + on_ptc = ( + sim.calculate("aca_ptc", map_to="person", period=2025).values > 0 + ).astype(int) loss_matrix[label] = sim.map_result(on_ptc, "person", "household") ACA_PTC_ENROLLMENT_2024 = 19_743_689 # people enrolled @@ -340,9 +348,13 @@ def build_loss_matrix(dataset: type, time_period): eitc_eligible_children = sim.calculate("eitc_child_count").values eitc = sim.calculate("eitc").values if row["count_children"] < 2: - meets_child_criteria = eitc_eligible_children == row["count_children"] + meets_child_criteria = ( + eitc_eligible_children == row["count_children"] + ) else: - meets_child_criteria = eitc_eligible_children >= row["count_children"] + meets_child_criteria = ( + eitc_eligible_children >= row["count_children"] + ) loss_matrix[returns_label] = sim.map_result( (eitc > 0) * meets_child_criteria, "tax_unit", @@ -396,7 +408,9 @@ def build_loss_matrix(dataset: type, time_period): # Hard-coded totals for variable_name, target in HARD_CODED_TOTALS.items(): label = f"nation/census/{variable_name}" - loss_matrix[label] = sim.calculate(variable_name, map_to="household").values + loss_matrix[label] = sim.calculate( + variable_name, map_to="household" + ).values if any(loss_matrix[label].isna()): raise ValueError(f"Missing values for {label}") targets_array.append(target) @@ -404,8 +418,8 @@ def build_loss_matrix(dataset: type, time_period): # Negative household market income total rough estimate from the IRS SOI PUF market_income = sim.calculate("household_market_income").values - loss_matrix["nation/irs/negative_household_market_income_total"] = market_income * ( - market_income < 0 + loss_matrix["nation/irs/negative_household_market_income_total"] = ( + market_income * (market_income < 0) ) targets_array.append(-138e9) @@ -436,27 +450,39 @@ def build_loss_matrix(dataset: type, time_period): # AGI by SPM threshold totals - spm_threshold_agi = pd.read_csv(CALIBRATION_FOLDER / "spm_threshold_agi.csv") + spm_threshold_agi = pd.read_csv( + CALIBRATION_FOLDER / "spm_threshold_agi.csv" + ) for _, row in spm_threshold_agi.iterrows(): - spm_unit_agi = sim.calculate("adjusted_gross_income", map_to="spm_unit").values + spm_unit_agi = sim.calculate( + "adjusted_gross_income", map_to="spm_unit" + ).values spm_threshold = sim.calculate("spm_unit_spm_threshold").values in_threshold_range = (spm_threshold >= row["lower_spm_threshold"]) * ( spm_threshold < row["upper_spm_threshold"] ) - label = f"nation/census/agi_in_spm_threshold_decile_{int(row['decile'])}" + label = ( + f"nation/census/agi_in_spm_threshold_decile_{int(row['decile'])}" + ) loss_matrix[label] = sim.map_result( in_threshold_range * spm_unit_agi, "spm_unit", "household" ) targets_array.append(row["adjusted_gross_income"]) - label = f"nation/census/count_in_spm_threshold_decile_{int(row['decile'])}" - loss_matrix[label] = sim.map_result(in_threshold_range, "spm_unit", "household") + label = ( + f"nation/census/count_in_spm_threshold_decile_{int(row['decile'])}" + ) + loss_matrix[label] = sim.map_result( + in_threshold_range, "spm_unit", "household" + ) targets_array.append(row["count"]) # Population by state and population under 5 by state - state_population = pd.read_csv(CALIBRATION_FOLDER / "population_by_state.csv") + state_population = pd.read_csv( + CALIBRATION_FOLDER / "population_by_state.csv" + ) for _, row in state_population.iterrows(): in_state = sim.calculate("state_code", map_to="person") == row["state"] @@ -467,7 +493,9 @@ def build_loss_matrix(dataset: type, time_period): under_5 = sim.calculate("age").values < 5 in_state_under_5 = in_state * under_5 label = f"state/census/population_under_5_by_state/{row['state']}" - loss_matrix[label] = sim.map_result(in_state_under_5, "person", "household") + loss_matrix[label] = sim.map_result( + in_state_under_5, "person", "household" + ) targets_array.append(row["population_under_5"]) age = sim.calculate("age").values @@ -491,7 +519,9 @@ def build_loss_matrix(dataset: type, time_period): # SALT tax expenditure targeting - _add_tax_expenditure_targets(dataset, time_period, sim, loss_matrix, targets_array) + _add_tax_expenditure_targets( + dataset, time_period, sim, loss_matrix, targets_array + ) if any(loss_matrix.isna().sum() > 0): raise ValueError("Some targets are missing from the loss matrix") @@ -505,7 +535,9 @@ def build_loss_matrix(dataset: type, time_period): # Overall count by SSN card type label = f"nation/ssa/ssn_card_type_{card_type_str.lower()}_count" - loss_matrix[label] = sim.map_result(ssn_type_mask, "person", "household") + loss_matrix[label] = sim.map_result( + ssn_type_mask, "person", "household" + ) # Target undocumented population by year based on various sources if card_type_str == "NONE": @@ -541,11 +573,14 @@ def build_loss_matrix(dataset: type, time_period): for _, row in spending_by_state.iterrows(): # Households located in this state in_state = ( - sim.calculate("state_code", map_to="household").values == row["state"] + sim.calculate("state_code", map_to="household").values + == row["state"] ) # ACA PTC amounts for every household (2025) - aca_value = sim.calculate("aca_ptc", map_to="household", period=2025).values + aca_value = sim.calculate( + "aca_ptc", map_to="household", period=2025 + ).values # Add a loss-matrix entry and matching target label = f"nation/irs/aca_spending/{row['state'].lower()}" @@ -578,7 +613,9 @@ def build_loss_matrix(dataset: type, time_period): in_state_enrolled = in_state & is_enrolled label = f"state/irs/aca_enrollment/{row['state'].lower()}" - loss_matrix[label] = sim.map_result(in_state_enrolled, "person", "household") + loss_matrix[label] = sim.map_result( + in_state_enrolled, "person", "household" + ) if any(loss_matrix[label].isna()): raise ValueError(f"Missing values for {label}") @@ -595,7 +632,9 @@ def build_loss_matrix(dataset: type, time_period): state_person = sim.calculate("state_code", map_to="person").values # Flag people in households that actually receive medicaid - has_medicaid = sim.calculate("medicaid_enrolled", map_to="person", period=2025) + has_medicaid = sim.calculate( + "medicaid_enrolled", map_to="person", period=2025 + ) is_medicaid_eligible = sim.calculate( "is_medicaid_eligible", map_to="person", period=2025 ).values @@ -607,7 +646,9 @@ def build_loss_matrix(dataset: type, time_period): in_state_enrolled = in_state & is_enrolled label = f"irs/medicaid_enrollment/{row['state'].lower()}" - loss_matrix[label] = sim.map_result(in_state_enrolled, "person", "household") + loss_matrix[label] = sim.map_result( + in_state_enrolled, "person", "household" + ) if any(loss_matrix[label].isna()): raise ValueError(f"Missing values for {label}") @@ -631,7 +672,9 @@ def build_loss_matrix(dataset: type, time_period): age_lower_bound = int(age_range.replace("+", "")) age_upper_bound = np.inf else: - age_lower_bound, age_upper_bound = map(int, age_range.split("-")) + age_lower_bound, age_upper_bound = map( + int, age_range.split("-") + ) age_mask = (age >= age_lower_bound) & (age <= age_upper_bound) label = f"state/census/age/{state}/{age_range}" @@ -702,7 +745,9 @@ def apply(self): simulation.default_calculation_period = time_period # Calculate the baseline and reform income tax values. - income_tax_r = simulation.calculate("income_tax", map_to="household").values + income_tax_r = simulation.calculate( + "income_tax", map_to="household" + ).values # Compute the tax expenditure (TE) values. te_values = income_tax_r - income_tax_b @@ -736,7 +781,9 @@ def _add_agi_state_targets(): + soi_targets["VARIABLE"] + "/" + soi_targets.apply( - lambda r: get_agi_band_label(r["AGI_LOWER_BOUND"], r["AGI_UPPER_BOUND"]), + lambda r: get_agi_band_label( + r["AGI_LOWER_BOUND"], r["AGI_UPPER_BOUND"] + ), axis=1, ) ) @@ -757,7 +804,9 @@ def _add_agi_metric_columns( agi = sim.calculate("adjusted_gross_income").values state = sim.calculate("state_code", map_to="person").values - state = sim.map_result(state, "person", "tax_unit", how="value_from_first_person") + state = sim.map_result( + state, "person", "tax_unit", how="value_from_first_person" + ) for _, r in soi_targets.iterrows(): lower, upper = r.AGI_LOWER_BOUND, r.AGI_UPPER_BOUND @@ -801,9 +850,13 @@ def _add_state_real_estate_taxes(loss_matrix, targets_list, sim): rtol=1e-8, ), "Real estate tax totals do not sum to national target" - targets_list.extend(real_estate_taxes_targets["real_estate_taxes_bn"].tolist()) + targets_list.extend( + real_estate_taxes_targets["real_estate_taxes_bn"].tolist() + ) - real_estate_taxes = sim.calculate("real_estate_taxes", map_to="household").values + real_estate_taxes = sim.calculate( + "real_estate_taxes", map_to="household" + ).values state = sim.calculate("state_code", map_to="household").values for _, r in real_estate_taxes_targets.iterrows(): @@ -826,16 +879,22 @@ def _add_snap_state_targets(sim): ).calibration.gov.cbo._children["snap"] ratio = snap_targets[["Cost"]].sum().values[0] / national_cost_target snap_targets[["CostAdj"]] = snap_targets[["Cost"]] / ratio - assert np.round(snap_targets[["CostAdj"]].sum().values[0]) == national_cost_target + assert ( + np.round(snap_targets[["CostAdj"]].sum().values[0]) + == national_cost_target + ) cost_targets = snap_targets.copy()[["GEO_ID", "CostAdj"]] - cost_targets["target_name"] = cost_targets["GEO_ID"].str[-4:] + "/snap-cost" + cost_targets["target_name"] = ( + cost_targets["GEO_ID"].str[-4:] + "/snap-cost" + ) hh_targets = snap_targets.copy()[["GEO_ID", "Households"]] hh_targets["target_name"] = snap_targets["GEO_ID"].str[-4:] + "/snap-hhs" target_names = ( - cost_targets["target_name"].tolist() + hh_targets["target_name"].tolist() + cost_targets["target_name"].tolist() + + hh_targets["target_name"].tolist() ) target_values = ( cost_targets["CostAdj"].astype(float).tolist() @@ -854,12 +913,14 @@ def _add_snap_metric_columns( snap_targets = pd.read_csv(CALIBRATION_FOLDER / "snap_state.csv") snap_cost = sim.calculate("snap_reported", map_to="household").values - snap_hhs = (sim.calculate("snap_reported", map_to="household").values > 0).astype( - int - ) + snap_hhs = ( + sim.calculate("snap_reported", map_to="household").values > 0 + ).astype(int) state = sim.calculate("state_code", map_to="person").values - state = sim.map_result(state, "person", "household", how="value_from_first_person") + state = sim.map_result( + state, "person", "household", how="value_from_first_person" + ) STATE_ABBR_TO_FIPS["DC"] = 11 state_fips = pd.Series(state).apply(lambda s: STATE_ABBR_TO_FIPS[s]) @@ -878,7 +939,9 @@ def _add_snap_metric_columns( return loss_matrix -def print_reweighting_diagnostics(optimised_weights, loss_matrix, targets_array, label): +def print_reweighting_diagnostics( + optimised_weights, loss_matrix, targets_array, label +): # Convert all inputs to NumPy arrays right at the start optimised_weights_np = ( optimised_weights.numpy() @@ -905,7 +968,9 @@ def print_reweighting_diagnostics(optimised_weights, loss_matrix, targets_array, # All subsequent calculations use the guaranteed NumPy versions estimate = optimised_weights_np @ loss_matrix_np - rel_error = (((estimate - targets_array_np) + 1) / (targets_array_np + 1)) ** 2 + rel_error = ( + ((estimate - targets_array_np) + 1) / (targets_array_np + 1) + ) ** 2 within_10_percent_mask = np.abs(estimate - targets_array_np) <= ( 0.10 * np.abs(targets_array_np) ) diff --git a/policyengine_us_data/utils/randomness.py b/policyengine_us_data/utils/randomness.py index 001dbf2f..eac01522 100644 --- a/policyengine_us_data/utils/randomness.py +++ b/policyengine_us_data/utils/randomness.py @@ -11,7 +11,9 @@ def _stable_string_hash(s: str) -> np.uint64: Ported from policyengine_core.commons.formulas._stable_string_hash. """ with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "overflow encountered", RuntimeWarning) + warnings.filterwarnings( + "ignore", "overflow encountered", RuntimeWarning + ) h = np.uint64(0) for byte in s.encode("utf-8"): h = h * np.uint64(31) + np.uint64(byte) diff --git a/policyengine_us_data/utils/soi.py b/policyengine_us_data/utils/soi.py index b9755c30..d9538add 100644 --- a/policyengine_us_data/utils/soi.py +++ b/policyengine_us_data/utils/soi.py @@ -11,7 +11,9 @@ def pe_to_soi(pe_dataset, year): pe_sim.default_calculation_period = year df = pd.DataFrame() - pe = lambda variable: np.array(pe_sim.calculate(variable, map_to="tax_unit")) + pe = lambda variable: np.array( + pe_sim.calculate(variable, map_to="tax_unit") + ) df["adjusted_gross_income"] = pe("adjusted_gross_income") df["exemption"] = pe("exemptions") @@ -49,8 +51,12 @@ def pe_to_soi(pe_dataset, year): df["total_pension_income"] = pe("pension_income") df["taxable_pension_income"] = pe("taxable_pension_income") df["qualified_dividends"] = pe("qualified_dividend_income") - df["rent_and_royalty_net_income"] = pe("rental_income") * (pe("rental_income") > 0) - df["rent_and_royalty_net_losses"] = -pe("rental_income") * (pe("rental_income") < 0) + df["rent_and_royalty_net_income"] = pe("rental_income") * ( + pe("rental_income") > 0 + ) + df["rent_and_royalty_net_losses"] = -pe("rental_income") * ( + pe("rental_income") < 0 + ) df["total_social_security"] = pe("social_security") df["taxable_social_security"] = pe("taxable_social_security") df["income_tax_before_credits"] = pe("income_tax_before_credits") @@ -170,7 +176,8 @@ def get_soi(year: int) -> pd.DataFrame: pe_name = uprating_map.get(variable) if pe_name in uprating.index: uprating_factors[variable] = ( - uprating.loc[pe_name, year] / uprating.loc[pe_name, soi.Year.max()] + uprating.loc[pe_name, year] + / uprating.loc[pe_name, soi.Year.max()] ) else: uprating_factors[variable] = ( @@ -211,7 +218,9 @@ def compare_soi_replication_to_soi(df, soi): elif fs == "Head of Household": subset = subset[subset.filing_status == "HEAD_OF_HOUSEHOLD"] elif fs == "Married Filing Jointly/Surviving Spouse": - subset = subset[subset.filing_status.isin(["JOINT", "SURVIVING_SPOUSE"])] + subset = subset[ + subset.filing_status.isin(["JOINT", "SURVIVING_SPOUSE"]) + ] elif fs == "Married Filing Separately": subset = subset[subset.filing_status == "SEPARATE"] @@ -249,13 +258,17 @@ def compare_soi_replication_to_soi(df, soi): } ) - soi_replication["Error"] = soi_replication["Value"] - soi_replication["SOI Value"] + soi_replication["Error"] = ( + soi_replication["Value"] - soi_replication["SOI Value"] + ) soi_replication["Absolute error"] = soi_replication["Error"].abs() soi_replication["Relative error"] = ( (soi_replication["Error"] / soi_replication["SOI Value"]) .replace([np.inf, -np.inf], np.nan) .fillna(0) ) - soi_replication["Absolute relative error"] = soi_replication["Relative error"].abs() + soi_replication["Absolute relative error"] = soi_replication[ + "Relative error" + ].abs() return soi_replication diff --git a/policyengine_us_data/utils/spm.py b/policyengine_us_data/utils/spm.py index ad3c9e9f..b2e4538b 100644 --- a/policyengine_us_data/utils/spm.py +++ b/policyengine_us_data/utils/spm.py @@ -44,7 +44,9 @@ def calculate_spm_thresholds_with_geoadj( for i in range(n): tenure_str = TENURE_CODE_MAP.get(int(tenure_codes[i]), "renter") base = base_thresholds[tenure_str] - equiv_scale = spm_equivalence_scale(int(num_adults[i]), int(num_children[i])) + equiv_scale = spm_equivalence_scale( + int(num_adults[i]), int(num_children[i]) + ) thresholds[i] = base * equiv_scale * geoadj[i] return thresholds diff --git a/policyengine_us_data/utils/uprating.py b/policyengine_us_data/utils/uprating.py index 41d223b0..6dd2f89c 100644 --- a/policyengine_us_data/utils/uprating.py +++ b/policyengine_us_data/utils/uprating.py @@ -23,7 +23,9 @@ def create_policyengine_uprating_factors_table(): parameter = system.parameters.get_child(variable.uprating) start_value = parameter(START_YEAR) for year in range(START_YEAR, END_YEAR + 1): - population_growth = population_size(year) / population_size(START_YEAR) + population_growth = population_size(year) / population_size( + START_YEAR + ) variable_names.append(variable.name) years.append(year) growth = parameter(year) / start_value diff --git a/tests/test_h6_reform.py b/tests/test_h6_reform.py index 2acdd8cc..e68ed8db 100644 --- a/tests/test_h6_reform.py +++ b/tests/test_h6_reform.py @@ -27,13 +27,17 @@ def calculate_oasdi_thresholds(year: int) -> tuple[int, int]: return oasdi_single, oasdi_joint -def get_swapped_thresholds(oasdi_threshold: int, hi_threshold: int) -> tuple[int, int]: +def get_swapped_thresholds( + oasdi_threshold: int, hi_threshold: int +) -> tuple[int, int]: """ Apply min/max swap to handle threshold crossover. Returns (base_threshold, adjusted_threshold) where base <= adjusted. """ - return min(oasdi_threshold, hi_threshold), max(oasdi_threshold, hi_threshold) + return min(oasdi_threshold, hi_threshold), max( + oasdi_threshold, hi_threshold + ) def needs_crossover_swap(oasdi_threshold: int, hi_threshold: int) -> bool: @@ -141,7 +145,9 @@ def test_single_crossover_starts_2046(self): # 2046+: crossover for year in range(2046, 2054): oasdi_single, _ = calculate_oasdi_thresholds(year) - assert needs_crossover_swap(oasdi_single, HI_SINGLE), f"Year {year}" + assert needs_crossover_swap( + oasdi_single, HI_SINGLE + ), f"Year {year}" class TestH6ThresholdSwapping: @@ -205,9 +211,9 @@ def test_2045_error_analysis(self): assert single_error_swapped == pytest.approx(225) assert joint_error_default == pytest.approx(3_150) - assert joint_error_default / single_error_swapped == pytest.approx(14.0), ( - "Swapped rates should have 14x less error" - ) + assert joint_error_default / single_error_swapped == pytest.approx( + 14.0 + ), "Swapped rates should have 14x less error" def test_swapped_rates_align_with_tax_cut_intent(self): """Swapped rates undertax (not overtax), aligning with reform intent.""" diff --git a/tests/test_no_formula_variables_stored.py b/tests/test_no_formula_variables_stored.py index 7c7cb0de..9334a5c7 100644 --- a/tests/test_no_formula_variables_stored.py +++ b/tests/test_no_formula_variables_stored.py @@ -109,7 +109,11 @@ def test_stored_values_match_computed( computed_total = np.sum(computed.astype(float)) if abs(stored_total) > 0: - pct_diff = abs(stored_total - computed_total) / abs(stored_total) * 100 + pct_diff = ( + abs(stored_total - computed_total) + / abs(stored_total) + * 100 + ) else: pct_diff = 0 @@ -137,13 +141,23 @@ def test_ss_subcomponents_sum_to_computed_total(sim, dataset_path): stored in the dataset sum to the simulation's computed total. """ with h5py.File(dataset_path, "r") as f: - ss_retirement = f["social_security_retirement"]["2024"][...].astype(float) - ss_disability = f["social_security_disability"]["2024"][...].astype(float) - ss_survivors = f["social_security_survivors"]["2024"][...].astype(float) - ss_dependents = f["social_security_dependents"]["2024"][...].astype(float) + ss_retirement = f["social_security_retirement"]["2024"][...].astype( + float + ) + ss_disability = f["social_security_disability"]["2024"][...].astype( + float + ) + ss_survivors = f["social_security_survivors"]["2024"][...].astype( + float + ) + ss_dependents = f["social_security_dependents"]["2024"][...].astype( + float + ) sub_sum = ss_retirement + ss_disability + ss_survivors + ss_dependents - computed_total = np.array(sim.calculate("social_security", 2024)).astype(float) + computed_total = np.array(sim.calculate("social_security", 2024)).astype( + float + ) # Only check records that have any SS income has_ss = computed_total > 0 diff --git a/tests/test_reproducibility.py b/tests/test_reproducibility.py index 25755f0a..1ec097a7 100644 --- a/tests/test_reproducibility.py +++ b/tests/test_reproducibility.py @@ -144,9 +144,9 @@ def test_output_checksums(self): if file_path.exists() and filename != "checksums.txt": with open(file_path, "rb") as f: actual_checksum = hashlib.sha256(f.read()).hexdigest() - assert actual_checksum == expected_checksum, ( - f"Checksum mismatch for {filename}" - ) + assert ( + actual_checksum == expected_checksum + ), f"Checksum mismatch for {filename}" def test_memory_usage(self): """Test that memory usage stays within bounds.""" diff --git a/tests/test_weeks_unemployed.py b/tests/test_weeks_unemployed.py index d64d8b64..18aa4762 100644 --- a/tests/test_weeks_unemployed.py +++ b/tests/test_weeks_unemployed.py @@ -21,9 +21,9 @@ def test_lkweeks_in_person_columns(self): # Check for correct variable assert '"LKWEEKS"' in content, "LKWEEKS should be in PERSON_COLUMNS" - assert '"WKSUNEM"' not in content, ( - "WKSUNEM should not be in PERSON_COLUMNS (Census uses LKWEEKS)" - ) + assert ( + '"WKSUNEM"' not in content + ), "WKSUNEM should not be in PERSON_COLUMNS (Census uses LKWEEKS)" def test_cps_uses_lkweeks(self): """Test that cps.py uses LKWEEKS, not WKSUNEM.""" diff --git a/validation/benefit_validation.py b/validation/benefit_validation.py index cf468972..d614ae03 100644 --- a/validation/benefit_validation.py +++ b/validation/benefit_validation.py @@ -50,7 +50,9 @@ def analyze_benefit_underreporting(): # Participation participants = (benefit > 0).sum() - weighted_participants = ((benefit > 0) * weight).sum() / 1e6 # millions + weighted_participants = ( + (benefit > 0) * weight + ).sum() / 1e6 # millions # Underreporting factor underreporting = info["admin_total"] / total if total > 0 else np.inf @@ -166,7 +168,9 @@ def earnings_reform(parameters): earnings_change = earnings * pct_increase / 100 net_change = reformed_net - original_net - emtr = np.where(earnings_change > 0, 1 - (net_change / earnings_change), 0) + emtr = np.where( + earnings_change > 0, 1 - (net_change / earnings_change), 0 + ) # Focus on sample sample_emtr = emtr[sample] @@ -250,7 +254,9 @@ def analyze_aca_subsidies(): total_ptc = (ptc[mask] * weight[mask]).sum() / 1e9 recipients = ((ptc > 0) & mask).sum() weighted_recipients = (((ptc > 0) & mask) * weight).sum() / 1e6 - mean_ptc = ptc[(ptc > 0) & mask].mean() if ((ptc > 0) & mask).any() else 0 + mean_ptc = ( + ptc[(ptc > 0) & mask].mean() if ((ptc > 0) & mask).any() else 0 + ) results.append( { @@ -301,7 +307,9 @@ def generate_benefit_validation_report(): print("\n\n4. Top 10 States by SNAP Benefits") print("-" * 40) state_df = validate_state_benefits() - top_states = state_df.nlargest(10, "snap_billions")[["state_code", "snap_billions"]] + top_states = state_df.nlargest(10, "snap_billions")[ + ["state_code", "snap_billions"] + ] print(top_states.to_string(index=False)) # ACA analysis @@ -311,7 +319,9 @@ def generate_benefit_validation_report(): print(aca_df.to_string(index=False)) # Save results - underreporting_df.to_csv("validation/benefit_underreporting.csv", index=False) + underreporting_df.to_csv( + "validation/benefit_underreporting.csv", index=False + ) interactions_df.to_csv("validation/program_interactions.csv", index=False) emtr_df.to_csv("validation/effective_marginal_tax_rates.csv", index=False) state_df.to_csv("validation/state_benefit_totals.csv", index=False) diff --git a/validation/generate_qrf_statistics.py b/validation/generate_qrf_statistics.py index 4015fe1e..4a026dea 100644 --- a/validation/generate_qrf_statistics.py +++ b/validation/generate_qrf_statistics.py @@ -222,14 +222,18 @@ print(support_df.round(3).to_string()) print("\nSummary:") -print(f"- Average overlap coefficient: {support_df['overlap_coefficient'].mean():.3f}") +print( + f"- Average overlap coefficient: {support_df['overlap_coefficient'].mean():.3f}" +) print( f"- All overlap coefficients > 0.85: {(support_df['overlap_coefficient'] > 0.85).all()}" ) print( f"- Variables with SMD > 0.25: {(support_df['standardized_mean_diff'] > 0.25).sum()}" ) -print(f"- All SMDs < 0.25: {(support_df['standardized_mean_diff'] < 0.25).all()}") +print( + f"- All SMDs < 0.25: {(support_df['standardized_mean_diff'] < 0.25).all()}" +) print( f"- Variables with significant KS test (p<0.05): {(support_df['ks_pvalue'] < 0.05).sum()}" ) @@ -275,7 +279,9 @@ print( f"- All correlation differences < 0.05: {(joint_df['correlation_diff'] < 0.05).all()}" ) -print(f"- Average correlation difference: {joint_df['correlation_diff'].mean():.3f}") +print( + f"- Average correlation difference: {joint_df['correlation_diff'].mean():.3f}" +) # Save all results print("\n\nSAVING RESULTS...") @@ -288,7 +294,9 @@ ) accuracy_df.to_csv("validation/outputs/qrf_accuracy_metrics.csv") -print("✓ Saved accuracy metrics to validation/outputs/qrf_accuracy_metrics.csv") +print( + "✓ Saved accuracy metrics to validation/outputs/qrf_accuracy_metrics.csv" +) joint_df.to_csv("validation/outputs/joint_distribution_tests.csv", index=False) print( @@ -301,7 +309,9 @@ f.write("=" * 40 + "\n\n") for var, r2 in variance_explained.items(): f.write(f"{var.replace('_', ' ').title()}: {r2 * 100:.0f}%\n") -print("✓ Saved variance explained to validation/outputs/variance_explained.txt") +print( + "✓ Saved variance explained to validation/outputs/variance_explained.txt" +) # Create summary report with open("validation/outputs/qrf_diagnostics_summary.txt", "w") as f: @@ -317,8 +327,12 @@ f.write( f"All overlap coefficients > 0.85: {(support_df['overlap_coefficient'] > 0.85).all()}\n" ) - f.write(f"All SMDs < 0.25: {(support_df['standardized_mean_diff'] < 0.25).all()}\n") - f.write(f"All KS tests p > 0.05: {(support_df['ks_pvalue'] > 0.05).all()}\n\n") + f.write( + f"All SMDs < 0.25: {(support_df['standardized_mean_diff'] < 0.25).all()}\n" + ) + f.write( + f"All KS tests p > 0.05: {(support_df['ks_pvalue'] > 0.05).all()}\n\n" + ) f.write("2. VARIANCE EXPLAINED\n") f.write("-" * 40 + "\n") @@ -347,7 +361,9 @@ ) f.write("\n" + "=" * 60 + "\n") - f.write("These statistics demonstrate that the QRF methodology successfully:\n") + f.write( + "These statistics demonstrate that the QRF methodology successfully:\n" + ) f.write("- Maintains strong common support between datasets\n") f.write("- Achieves high predictive accuracy for imputation\n") f.write("- Preserves joint distributions of variables\n") diff --git a/validation/qrf_diagnostics.py b/validation/qrf_diagnostics.py index d22f883c..4e572916 100644 --- a/validation/qrf_diagnostics.py +++ b/validation/qrf_diagnostics.py @@ -28,7 +28,9 @@ def analyze_common_support(cps_data, puf_data, predictors): # Overlap coefficient (Weitzman 1970) # OVL = sum(min(f(x), g(x))) where f,g are densities - bins = np.histogram_bin_edges(np.concatenate([cps_dist, puf_dist]), bins=50) + bins = np.histogram_bin_edges( + np.concatenate([cps_dist, puf_dist]), bins=50 + ) cps_hist, _ = np.histogram(cps_dist, bins=bins, density=True) puf_hist, _ = np.histogram(puf_dist, bins=bins, density=True) @@ -79,7 +81,9 @@ def validate_qrf_accuracy(puf_data, predictors, target_vars, n_estimators=100): ) # Fit QRF - qrf = RandomForestQuantileRegressor(n_estimators=n_estimators, random_state=42) + qrf = RandomForestQuantileRegressor( + n_estimators=n_estimators, random_state=42 + ) qrf.fit(X_train, y_train) # Predictions at multiple quantiles @@ -120,7 +124,9 @@ def validate_qrf_accuracy(puf_data, predictors, target_vars, n_estimators=100): "qrf_rmse": rmse, "hotdeck_mae": hotdeck_mae, "linear_mae": lr_mae, - "qrf_improvement_vs_hotdeck": (hotdeck_mae - mae) / hotdeck_mae * 100, + "qrf_improvement_vs_hotdeck": (hotdeck_mae - mae) + / hotdeck_mae + * 100, "qrf_improvement_vs_linear": (lr_mae - mae) / lr_mae * 100, "coverage_90pct": coverage_90, "coverage_50pct": coverage_50, @@ -129,7 +135,9 @@ def validate_qrf_accuracy(puf_data, predictors, target_vars, n_estimators=100): return pd.DataFrame(results).T -def test_joint_distribution_preservation(original_data, imputed_data, var_pairs): +def test_joint_distribution_preservation( + original_data, imputed_data, var_pairs +): """Test whether joint distributions are preserved in imputation.""" results = [] @@ -151,12 +159,12 @@ def test_joint_distribution_preservation(original_data, imputed_data, var_pairs) # Joint distribution test (2D KS test approximation) # Using average of marginal KS statistics - ks1 = stats.ks_2samp(original_data[var1].dropna(), imputed_data[var1].dropna())[ - 0 - ] - ks2 = stats.ks_2samp(original_data[var2].dropna(), imputed_data[var2].dropna())[ - 0 - ] + ks1 = stats.ks_2samp( + original_data[var1].dropna(), imputed_data[var1].dropna() + )[0] + ks2 = stats.ks_2samp( + original_data[var2].dropna(), imputed_data[var2].dropna() + )[0] joint_ks = (ks1 + ks2) / 2 results.append( @@ -273,7 +281,9 @@ def generate_qrf_diagnostic_report(cps_data, puf_data, imputed_data): print( f"- Average QRF improvement vs linear: {accuracy_df['qrf_improvement_vs_linear'].mean():.1f}%" ) - print(f"- Average 90% coverage: {accuracy_df['coverage_90pct'].mean():.3f}") + print( + f"- Average 90% coverage: {accuracy_df['coverage_90pct'].mean():.3f}" + ) # Joint distribution preservation print("\n\n3. Joint Distribution Preservation") @@ -285,12 +295,16 @@ def generate_qrf_diagnostic_report(cps_data, puf_data, imputed_data): ("pension_income", "social_security"), ] - joint_df = test_joint_distribution_preservation(puf_data, imputed_data, var_pairs) + joint_df = test_joint_distribution_preservation( + puf_data, imputed_data, var_pairs + ) print(joint_df.to_string(index=False)) # Create diagnostic plots create_diagnostic_plots(cps_data, puf_data, predictors) - print("\n\nDiagnostic plots saved to validation/common_support_diagnostics.png") + print( + "\n\nDiagnostic plots saved to validation/common_support_diagnostics.png" + ) # Save results support_df.to_csv("validation/common_support_analysis.csv") diff --git a/validation/tax_policy_validation.py b/validation/tax_policy_validation.py index 9e04982f..c7c4f600 100644 --- a/validation/tax_policy_validation.py +++ b/validation/tax_policy_validation.py @@ -101,7 +101,9 @@ def analyze_high_income_taxpayers(): for threshold in thresholds: count = (weights[agi >= threshold]).sum() pct_returns = count / weights.sum() * 100 - total_agi = (agi[agi >= threshold] * weights[agi >= threshold]).sum() / 1e9 + total_agi = ( + agi[agi >= threshold] * weights[agi >= threshold] + ).sum() / 1e9 results.append( { @@ -133,7 +135,9 @@ def validate_state_revenues(): results.append({"state_code": state, "revenue_billions": total}) - return pd.DataFrame(results).sort_values("revenue_billions", ascending=False) + return pd.DataFrame(results).sort_values( + "revenue_billions", ascending=False + ) def generate_validation_report(): diff --git a/validation/validate_retirement_imputation.py b/validation/validate_retirement_imputation.py index 065a8294..6a11eafd 100644 --- a/validation/validate_retirement_imputation.py +++ b/validation/validate_retirement_imputation.py @@ -54,8 +54,12 @@ def validate_constraints(sim) -> list: issues = [] year = 2024 - emp_income = sim.calculate("employment_income", year, map_to="person").values - se_income = sim.calculate("self_employment_income", year, map_to="person").values + emp_income = sim.calculate( + "employment_income", year, map_to="person" + ).values + se_income = sim.calculate( + "self_employment_income", year, map_to="person" + ).values age = sim.calculate("age", year, map_to="person").values catch_up = age >= 50 @@ -75,7 +79,9 @@ def validate_constraints(sim) -> list: n_over_cap = (vals > max_401k + 1).sum() if n_over_cap > 0: - issues.append(f"FAIL: {var} has {n_over_cap} values exceeding 401k cap") + issues.append( + f"FAIL: {var} has {n_over_cap} values exceeding 401k cap" + ) zero_wage = emp_income == 0 n_nonzero_no_wage = (vals[zero_wage] > 0).sum() @@ -104,7 +110,9 @@ def validate_constraints(sim) -> list: n_over_cap = (vals > max_ira + 1).sum() if n_over_cap > 0: - issues.append(f"FAIL: {var} has {n_over_cap} values exceeding IRA cap") + issues.append( + f"FAIL: {var} has {n_over_cap} values exceeding IRA cap" + ) # SE pension constraint var = "self_employed_pension_contributions" @@ -133,7 +141,9 @@ def validate_aggregates(sim) -> list: weight = sim.calculate("person_weight", year).values - logger.info("\n%-45s %15s %15s %10s", "Variable", "Weighted Sum", "Target", "Ratio") + logger.info( + "\n%-45s %15s %15s %10s", "Variable", "Weighted Sum", "Target", "Ratio" + ) logger.info("-" * 90) for var, target in TARGETS.items(): From 7fb0a6a524bc64f06a8e301053295322138bb8ca Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 16 Mar 2026 20:52:59 +0100 Subject: [PATCH 7/8] Revert "Apply black formatting (79 char line length)" This reverts commit 69aa1f072ef7ec437623fb4e7288de9953955b01. --- .github/bump_version.py | 4 +- modal_app/data_build.py | 24 +-- modal_app/local_area.py | 56 ++---- modal_app/remote_calibration_runner.py | 40 +--- modal_app/worker_script.py | 4 +- paper/scripts/build_from_content.py | 36 +--- paper/scripts/calculate_target_performance.py | 3 +- paper/scripts/generate_all_tables.py | 8 +- paper/scripts/generate_validation_metrics.py | 4 +- paper/scripts/markdown_to_latex.py | 16 +- .../calibration/block_assignment.py | 32 +--- .../calibration/calibration_utils.py | 16 +- .../calibration/clone_and_assign.py | 4 +- .../calibration/county_assignment.py | 4 +- .../calibration/create_source_imputed_cps.py | 12 +- .../calibration/create_stratified_cps.py | 17 +- .../calibration/publish_local_area.py | 82 +++------ .../calibration/puf_impute.py | 50 ++--- .../calibration/sanity_checks.py | 4 +- .../calibration/source_impute.py | 47 ++--- .../calibration/stacked_dataset_builder.py | 4 +- .../calibration/unified_calibration.py | 43 ++--- .../calibration/unified_matrix_builder.py | 136 ++++---------- .../calibration/validate_national_h5.py | 4 +- .../calibration/validate_package.py | 24 +-- .../calibration/validate_staging.py | 28 +-- policyengine_us_data/datasets/acs/acs.py | 12 +- .../datasets/acs/census_acs.py | 22 +-- .../datasets/cps/census_cps.py | 32 +--- policyengine_us_data/datasets/cps/cps.py | 171 +++++------------- .../datasets/cps/enhanced_cps.py | 32 +--- .../check_calibrated_estimates_interactive.py | 66 +++---- .../cps/long_term/extract_ssa_costs.py | 4 +- .../cps/long_term/projection_utils.py | 16 +- .../cps/long_term/run_household_projection.py | 96 ++++------ .../datasets/cps/small_enhanced_cps.py | 15 +- policyengine_us_data/datasets/puf/irs_puf.py | 4 +- policyengine_us_data/datasets/puf/puf.py | 39 ++-- policyengine_us_data/datasets/scf/fed_scf.py | 16 +- policyengine_us_data/datasets/scf/scf.py | 36 +--- policyengine_us_data/datasets/sipp/sipp.py | 3 +- .../db/create_database_tables.py | 36 +--- .../db/create_initial_strata.py | 16 +- policyengine_us_data/db/etl_age.py | 8 +- policyengine_us_data/db/etl_irs_soi.py | 79 +++----- policyengine_us_data/db/etl_medicaid.py | 12 +- .../db/etl_national_targets.py | 52 ++---- policyengine_us_data/db/etl_pregnancy.py | 12 +- policyengine_us_data/db/etl_snap.py | 8 +- .../db/etl_state_income_tax.py | 10 +- policyengine_us_data/db/validate_database.py | 4 +- policyengine_us_data/db/validate_hierarchy.py | 52 ++---- policyengine_us_data/geography/__init__.py | 4 +- policyengine_us_data/geography/county_fips.py | 8 +- .../geography/create_zip_code_dataset.py | 4 +- policyengine_us_data/parameters/__init__.py | 4 +- .../calibration_targets/audit_county_enum.py | 4 +- .../make_block_cd_distributions.py | 8 +- .../make_block_crosswalk.py | 16 +- .../make_county_cd_distributions.py | 16 +- .../make_district_mapping.py | 8 +- .../pull_hardcoded_targets.py | 8 +- .../calibration_targets/pull_snap_targets.py | 8 +- .../calibration_targets/pull_soi_targets.py | 87 +++------ .../storage/upload_completed_datasets.py | 8 +- .../tests/test_calibration/conftest.py | 4 +- .../test_calibration/create_test_fixture.py | 32 +--- .../test_build_matrix_masking.py | 26 +-- .../test_calibration/test_clone_and_assign.py | 13 +- .../test_county_assignment.py | 8 +- .../tests/test_calibration/test_puf_impute.py | 8 +- .../test_retirement_imputation.py | 109 ++++------- .../test_calibration/test_source_impute.py | 4 +- .../test_stacked_dataset_builder.py | 58 +++--- .../test_calibration/test_target_config.py | 8 +- .../test_unified_calibration.py | 36 +--- .../test_unified_matrix_builder.py | 55 ++---- .../test_calibration/test_xw_consistency.py | 11 +- .../tests/test_constraint_validation.py | 12 +- .../tests/test_database_build.py | 28 ++- .../tests/test_datasets/test_county_fips.py | 8 +- .../tests/test_datasets/test_cps.py | 17 +- .../test_datasets/test_dataset_sanity.py | 50 ++--- .../tests/test_datasets/test_enhanced_cps.py | 54 ++---- .../tests/test_datasets/test_sipp_assets.py | 28 ++- .../test_datasets/test_small_enhanced_cps.py | 10 +- .../test_datasets/test_sparse_enhanced_cps.py | 28 +-- .../tests/test_format_comparison.py | 55 ++---- policyengine_us_data/tests/test_puf_impute.py | 4 +- .../tests/test_schema_views_and_lookups.py | 12 +- policyengine_us_data/utils/census.py | 4 +- .../utils/constraint_validation.py | 16 +- policyengine_us_data/utils/data_upload.py | 20 +- policyengine_us_data/utils/db.py | 29 +-- policyengine_us_data/utils/huggingface.py | 12 +- policyengine_us_data/utils/loss.py | 147 +++++---------- policyengine_us_data/utils/randomness.py | 4 +- policyengine_us_data/utils/soi.py | 27 +-- policyengine_us_data/utils/spm.py | 4 +- policyengine_us_data/utils/uprating.py | 4 +- tests/test_h6_reform.py | 18 +- tests/test_no_formula_variables_stored.py | 26 +-- tests/test_reproducibility.py | 6 +- tests/test_weeks_unemployed.py | 6 +- validation/benefit_validation.py | 20 +- validation/generate_qrf_statistics.py | 32 +--- validation/qrf_diagnostics.py | 40 ++-- validation/tax_policy_validation.py | 8 +- validation/validate_retirement_imputation.py | 20 +- 109 files changed, 834 insertions(+), 2025 deletions(-) diff --git a/.github/bump_version.py b/.github/bump_version.py index bb0fd6dd..779a82e3 100644 --- a/.github/bump_version.py +++ b/.github/bump_version.py @@ -19,9 +19,7 @@ def get_current_version(pyproject_path: Path) -> str: def infer_bump(changelog_dir: Path) -> str: fragments = [ - f - for f in changelog_dir.iterdir() - if f.is_file() and f.name != ".gitkeep" + f for f in changelog_dir.iterdir() if f.is_file() and f.name != ".gitkeep" ] if not fragments: print("No changelog fragments found", file=sys.stderr) diff --git a/modal_app/data_build.py b/modal_app/data_build.py index 2a0310c4..197ee32b 100644 --- a/modal_app/data_build.py +++ b/modal_app/data_build.py @@ -21,9 +21,7 @@ ) image = ( - modal.Image.debian_slim(python_version="3.13") - .apt_install("git") - .pip_install("uv") + modal.Image.debian_slim(python_version="3.13").apt_install("git").pip_install("uv") ) REPO_URL = "https://github.com/PolicyEngine/policyengine-us-data.git" @@ -92,9 +90,7 @@ def setup_gcp_credentials(): @functools.cache def get_current_commit() -> str: """Get the current git commit SHA (cached per process).""" - return subprocess.check_output( - ["git", "rev-parse", "HEAD"], text=True - ).strip() + return subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip() def get_checkpoint_path(branch: str, output_file: str) -> Path: @@ -404,9 +400,7 @@ def build_datasets( print("=== Phase 3: Building extended CPS ===") run_script_with_checkpoint( "policyengine_us_data/datasets/cps/extended_cps.py", - SCRIPT_OUTPUTS[ - "policyengine_us_data/datasets/cps/extended_cps.py" - ], + SCRIPT_OUTPUTS["policyengine_us_data/datasets/cps/extended_cps.py"], branch, checkpoint_volume, env=env, @@ -414,17 +408,13 @@ def build_datasets( # GROUP 3: After extended_cps - run in parallel # enhanced_cps and stratified_cps both depend on extended_cps - print( - "=== Phase 4: Building enhanced and stratified CPS (parallel) ===" - ) + print("=== Phase 4: Building enhanced and stratified CPS (parallel) ===") with ThreadPoolExecutor(max_workers=2) as executor: futures = [ executor.submit( run_script_with_checkpoint, "policyengine_us_data/datasets/cps/enhanced_cps.py", - SCRIPT_OUTPUTS[ - "policyengine_us_data/datasets/cps/enhanced_cps.py" - ], + SCRIPT_OUTPUTS["policyengine_us_data/datasets/cps/enhanced_cps.py"], branch, checkpoint_volume, env=env, @@ -447,9 +437,7 @@ def build_datasets( print("=== Phase 5: Building small enhanced CPS ===") run_script_with_checkpoint( "policyengine_us_data/datasets/cps/small_enhanced_cps.py", - SCRIPT_OUTPUTS[ - "policyengine_us_data/datasets/cps/small_enhanced_cps.py" - ], + SCRIPT_OUTPUTS["policyengine_us_data/datasets/cps/small_enhanced_cps.py"], branch, checkpoint_volume, env=env, diff --git a/modal_app/local_area.py b/modal_app/local_area.py index f13ae216..0b0670d2 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -104,9 +104,7 @@ def validate_artifacts( artifacts = config.get("artifacts", {}) if not artifacts: - print( - "WARNING: No artifacts section in run config, skipping validation" - ) + print("WARNING: No artifacts section in run config, skipping validation") return for filename, expected_hash in artifacts.items(): @@ -128,9 +126,7 @@ def validate_artifacts( f" Actual: {actual}" ) - print( - f"Validated {len(artifacts)} artifact(s) against run config checksums" - ) + print(f"Validated {len(artifacts)} artifact(s) against run config checksums") def get_version() -> str: @@ -211,15 +207,11 @@ def run_phase( version_dir: Path, ) -> set: """Run a single build phase, spawning workers and collecting results.""" - work_chunks = partition_work( - states, districts, cities, num_workers, completed - ) + work_chunks = partition_work(states, districts, cities, num_workers, completed) total_remaining = sum(len(c) for c in work_chunks) print(f"\n--- Phase: {phase_name} ---") - print( - f"Remaining work: {total_remaining} items across {len(work_chunks)} workers" - ) + print(f"Remaining work: {total_remaining} items across {len(work_chunks)} workers") if total_remaining == 0: print(f"All {phase_name} items already built!") @@ -408,9 +400,7 @@ def validate_staging(branch: str, version: str) -> Dict: print(f" States: {manifest['totals']['states']}") print(f" Districts: {manifest['totals']['districts']}") print(f" Cities: {manifest['totals']['cities']}") - print( - f" Total size: {manifest['totals']['total_size_bytes'] / 1e9:.2f} GB" - ) + print(f" Total size: {manifest['totals']['total_size_bytes'] / 1e9:.2f} GB") return manifest @@ -569,7 +559,9 @@ def promote_publish(branch: str = "main", version: str = "") -> str: if result.returncode != 0: raise RuntimeError(f"Promote failed: {result.stderr}") - return f"Successfully promoted version {version} with {len(manifest['files'])} files" + return ( + f"Successfully promoted version {version} with {len(manifest['files'])} files" + ) @app.function( @@ -621,15 +613,11 @@ def coordinate_publish( "dataset": dataset_path, "database": db_path, "geography": (calibration_dir / "calibration" / "geography.npz"), - "run_config": ( - calibration_dir / "calibration" / "unified_run_config.json" - ), + "run_config": (calibration_dir / "calibration" / "unified_run_config.json"), } for label, p in required.items(): if not p.exists(): - raise RuntimeError( - f"Missing required calibration input ({label}): {p}" - ) + raise RuntimeError(f"Missing required calibration input ({label}): {p}") print("All required calibration inputs found on volume.") else: if calibration_dir.exists(): @@ -658,15 +646,11 @@ def coordinate_publish( print("Calibration inputs downloaded") dataset_path = ( - calibration_dir - / "calibration" - / "source_imputed_stratified_extended_cps.h5" + calibration_dir / "calibration" / "source_imputed_stratified_extended_cps.h5" ) geo_npz_path = calibration_dir / "calibration" / "geography.npz" - config_json_path = ( - calibration_dir / "calibration" / "unified_run_config.json" - ) + config_json_path = calibration_dir / "calibration" / "unified_run_config.json" calibration_inputs = { "weights": str(weights_path), "dataset": str(dataset_path), @@ -782,14 +766,10 @@ def coordinate_publish( ) if actual_total < expected_total: - print( - f"WARNING: Expected {expected_total} files, found {actual_total}" - ) + print(f"WARNING: Expected {expected_total} files, found {actual_total}") print("\nStarting upload to staging...") - result = upload_to_staging.remote( - branch=branch, version=version, manifest=manifest - ) + result = upload_to_staging.remote(branch=branch, version=version, manifest=manifest) print(result) print("\n" + "=" * 60) @@ -873,14 +853,10 @@ def coordinate_national_publish( staging_volume.commit() print("National calibration inputs downloaded") - weights_path = ( - calibration_dir / "calibration" / "national_calibration_weights.npy" - ) + weights_path = calibration_dir / "calibration" / "national_calibration_weights.npy" db_path = calibration_dir / "calibration" / "policy_data.db" dataset_path = ( - calibration_dir - / "calibration" - / "source_imputed_stratified_extended_cps.h5" + calibration_dir / "calibration" / "source_imputed_stratified_extended_cps.h5" ) geo_npz_path = calibration_dir / "calibration" / "national_geography.npz" diff --git a/modal_app/remote_calibration_runner.py b/modal_app/remote_calibration_runner.py index 4853c719..075d5948 100644 --- a/modal_app/remote_calibration_runner.py +++ b/modal_app/remote_calibration_runner.py @@ -5,14 +5,10 @@ app = modal.App("policyengine-us-data-fit-weights") hf_secret = modal.Secret.from_name("huggingface-token") -calibration_vol = modal.Volume.from_name( - "calibration-data", create_if_missing=True -) +calibration_vol = modal.Volume.from_name("calibration-data", create_if_missing=True) image = ( - modal.Image.debian_slim(python_version="3.11") - .apt_install("git") - .pip_install("uv") + modal.Image.debian_slim(python_version="3.11").apt_install("git").pip_install("uv") ) REPO_URL = "https://github.com/PolicyEngine/policyengine-us-data.git" @@ -52,9 +48,7 @@ def _clone_and_install(branch: str): subprocess.run(["uv", "sync", "--extra", "l0"], check=True) -def _append_hyperparams( - cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq=None -): +def _append_hyperparams(cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq=None): """Append optional hyperparameter flags to a command list.""" if beta is not None: cmd.extend(["--beta", str(beta)]) @@ -271,9 +265,7 @@ def _fit_weights_impl( cmd.append("--county-level") if workers > 1: cmd.extend(["--workers", str(workers)]) - _append_hyperparams( - cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq - ) + _append_hyperparams(cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq) cal_rc, cal_lines = _run_streaming( cmd, @@ -330,9 +322,7 @@ def _fit_from_package_impl( ] if target_config: cmd.extend(["--target-config", target_config]) - _append_hyperparams( - cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq - ) + _append_hyperparams(cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq) print(f"Running command: {' '.join(cmd)}", flush=True) @@ -347,9 +337,7 @@ def _fit_from_package_impl( return _collect_outputs(cal_lines) -def _print_provenance_from_meta( - meta: dict, current_branch: str = None -) -> None: +def _print_provenance_from_meta(meta: dict, current_branch: str = None) -> None: """Print provenance info and warn on branch mismatch.""" built = meta.get("created_at", "unknown") branch = meta.get("git_branch", "unknown") @@ -526,9 +514,7 @@ def check_volume_package() -> dict: return {"exists": False} stat = os.stat(pkg_path) - mtime = datetime.datetime.fromtimestamp( - stat.st_mtime, tz=datetime.timezone.utc - ) + mtime = datetime.datetime.fromtimestamp(stat.st_mtime, tz=datetime.timezone.utc) info = { "exists": True, "size": stat.st_size, @@ -1026,9 +1012,7 @@ def main( if vol_info.get("created_at") or vol_info.get("git_branch"): _print_provenance_from_meta(vol_info, branch) mode_label = ( - "national calibration" - if national - else "fitting from pre-built package" + "national calibration" if national else "fitting from pre-built package" ) print( "========================================", @@ -1121,12 +1105,8 @@ def main( upload_calibration_artifacts( weights_path=output, blocks_path=(blocks_output if result.get("blocks") else None), - geo_labels_path=( - geo_labels_output if result.get("geo_labels") else None - ), - geography_path=( - geography_output if result.get("geography") else None - ), + geo_labels_path=(geo_labels_output if result.get("geo_labels") else None), + geography_path=(geography_output if result.get("geography") else None), log_dir=".", prefix=prefix, ) diff --git a/modal_app/worker_script.py b/modal_app/worker_script.py index 3267d0fb..f36b59a0 100644 --- a/modal_app/worker_script.py +++ b/modal_app/worker_script.py @@ -94,9 +94,7 @@ def main(): if state_fips is None: raise ValueError(f"Unknown state code: {item_id}") cd_subset = [ - cd - for cd in cds_to_calibrate - if int(cd) // 100 == state_fips + cd for cd in cds_to_calibrate if int(cd) // 100 == state_fips ] if not cd_subset: print( diff --git a/paper/scripts/build_from_content.py b/paper/scripts/build_from_content.py index 21068f0d..52f88389 100644 --- a/paper/scripts/build_from_content.py +++ b/paper/scripts/build_from_content.py @@ -47,12 +47,8 @@ def md_to_latex(self, content, section_type="section"): latex = re.sub(r"^# Abstract\n\n", "", latex) else: # Convert markdown headers to LaTeX sections - latex = re.sub( - r"^# (.+)$", r"\\section{\1}", latex, flags=re.MULTILINE - ) - latex = re.sub( - r"^## (.+)$", r"\\subsection{\1}", latex, flags=re.MULTILINE - ) + latex = re.sub(r"^# (.+)$", r"\\section{\1}", latex, flags=re.MULTILINE) + latex = re.sub(r"^## (.+)$", r"\\subsection{\1}", latex, flags=re.MULTILINE) latex = re.sub( r"^### (.+)$", r"\\subsubsection{\1}", @@ -173,15 +169,11 @@ def convert_citation(match): if len(author_list) == 1: # Handle "Author1 and Author2" format if " and " in authors: - first_author = ( - authors.split(" and ")[0].strip().split()[-1] - ) + first_author = authors.split(" and ")[0].strip().split()[-1] cite_key = f"{first_author.lower()}{year}" else: # Single author - author = ( - author_list[0].strip().split()[-1] - ) # Last name + author = author_list[0].strip().split()[-1] # Last name cite_key = f"{author.lower()}{year}" else: # Multiple authors - use first author @@ -191,9 +183,7 @@ def convert_citation(match): return f"\\citep{{{cite_key}}}" return match.group(0) # Return original if no year found - latex = re.sub( - r"\(([^)]+(?:19|20)\d{2}[a-z]?)\)", convert_citation, latex - ) + latex = re.sub(r"\(([^)]+(?:19|20)\d{2}[a-z]?)\)", convert_citation, latex) # Also handle inline citations like "Author (Year)" or "Author et al. (Year)" def convert_inline_citation(match): @@ -276,15 +266,11 @@ def convert_myst_citation(match): if len(author_list) == 1: # Handle "Author1 and Author2" format if " and " in authors: - first_author = ( - authors.split(" and ")[0].strip().split()[-1] - ) + first_author = authors.split(" and ")[0].strip().split()[-1] cite_key = f"{first_author.lower()}{year}" else: # Single author - author = ( - author_list[0].strip().split()[-1] - ) # Last name + author = author_list[0].strip().split()[-1] # Last name cite_key = f"{author.lower()}{year}" else: # Multiple authors - use first author @@ -294,9 +280,7 @@ def convert_myst_citation(match): return f"{{cite}}`{cite_key}`" return match.group(0) - myst = re.sub( - r"\(([^)]+(?:19|20)\d{2}[a-z]?)\)", convert_myst_citation, myst - ) + myst = re.sub(r"\(([^)]+(?:19|20)\d{2}[a-z]?)\)", convert_myst_citation, myst) # Handle inline citations like "Author (Year)" - convert to {cite:t}`author_year` def convert_inline_myst(match): @@ -343,9 +327,7 @@ def process_content_file(self, content_file): # LaTeX conversion if stem == "abstract": latex_content = self.md_to_latex(content, section_type="abstract") - latex_content = ( - f"\\begin{{abstract}}\n{latex_content}\n\\end{{abstract}}" - ) + latex_content = f"\\begin{{abstract}}\n{latex_content}\n\\end{{abstract}}" latex_path = self.paper_dir / "abstract.tex" elif stem == "introduction": latex_content = self.md_to_latex(content) diff --git a/paper/scripts/calculate_target_performance.py b/paper/scripts/calculate_target_performance.py index 1a50ab3c..8f5a65f1 100644 --- a/paper/scripts/calculate_target_performance.py +++ b/paper/scripts/calculate_target_performance.py @@ -79,8 +79,7 @@ def compare_dataset_performance( # Calculate average improvement by target category categories = { - "IRS Income": lambda x: "employment_income" in x - or "capital_gains" in x, + "IRS Income": lambda x: "employment_income" in x or "capital_gains" in x, "Demographics": lambda x: "age_" in x or "population" in x, "Programs": lambda x: "snap" in x or "social_security" in x, "Tax Expenditures": lambda x: "salt" in x or "charitable" in x, diff --git a/paper/scripts/generate_all_tables.py b/paper/scripts/generate_all_tables.py index 8f476203..690b528d 100644 --- a/paper/scripts/generate_all_tables.py +++ b/paper/scripts/generate_all_tables.py @@ -33,9 +33,7 @@ def create_latex_table(df, caption, label, float_format=None): # Format the dataframe as LaTeX if float_format: - table_body = df.to_latex( - index=False, escape=False, float_format=float_format - ) + table_body = df.to_latex(index=False, escape=False, float_format=float_format) else: table_body = df.to_latex(index=False, escape=False) @@ -44,9 +42,7 @@ def create_latex_table(df, caption, label, float_format=None): tabular_start = next( i for i, line in enumerate(lines) if "\\begin{tabular}" in line ) - tabular_end = next( - i for i, line in enumerate(lines) if "\\end{tabular}" in line - ) + tabular_end = next(i for i, line in enumerate(lines) if "\\end{tabular}" in line) # Indent the tabular content for i in range(tabular_start, tabular_end + 1): diff --git a/paper/scripts/generate_validation_metrics.py b/paper/scripts/generate_validation_metrics.py index db586959..90b3624d 100644 --- a/paper/scripts/generate_validation_metrics.py +++ b/paper/scripts/generate_validation_metrics.py @@ -235,9 +235,7 @@ def main(): print(f"\nResults saved to {results_dir}/") print("\nNOTE: All metrics marked as [TO BE CALCULATED] require full") - print( - "dataset generation and microsimulation runs to compute actual values." - ) + print("dataset generation and microsimulation runs to compute actual values.") if __name__ == "__main__": diff --git a/paper/scripts/markdown_to_latex.py b/paper/scripts/markdown_to_latex.py index 5c3b0e3b..7cc80b04 100644 --- a/paper/scripts/markdown_to_latex.py +++ b/paper/scripts/markdown_to_latex.py @@ -24,12 +24,8 @@ def convert_markdown_to_latex(markdown_content: str) -> str: # Convert headers latex = re.sub(r"^# (.+)$", r"\\section{\1}", latex, flags=re.MULTILINE) - latex = re.sub( - r"^## (.+)$", r"\\subsection{\1}", latex, flags=re.MULTILINE - ) - latex = re.sub( - r"^### (.+)$", r"\\subsubsection{\1}", latex, flags=re.MULTILINE - ) + latex = re.sub(r"^## (.+)$", r"\\subsection{\1}", latex, flags=re.MULTILINE) + latex = re.sub(r"^### (.+)$", r"\\subsubsection{\1}", latex, flags=re.MULTILINE) # Convert bold and italic latex = re.sub(r"\*\*(.+?)\*\*", r"\\textbf{\1}", latex) @@ -67,9 +63,7 @@ def convert_markdown_to_latex(markdown_content: str) -> str: # Manage list stack while len(list_stack) > indent_level + 1: - new_lines.append( - " " * (len(list_stack) - 1) + "\\end{itemize}" - ) + new_lines.append(" " * (len(list_stack) - 1) + "\\end{itemize}") list_stack.pop() if len(list_stack) <= indent_level: @@ -81,9 +75,7 @@ def convert_markdown_to_latex(markdown_content: str) -> str: else: # Close any open lists while list_stack: - new_lines.append( - " " * (len(list_stack) - 1) + "\\end{itemize}" - ) + new_lines.append(" " * (len(list_stack) - 1) + "\\end{itemize}") list_stack.pop() new_lines.append(line) in_list = False diff --git a/policyengine_us_data/calibration/block_assignment.py b/policyengine_us_data/calibration/block_assignment.py index 3ce09289..83af388f 100644 --- a/policyengine_us_data/calibration/block_assignment.py +++ b/policyengine_us_data/calibration/block_assignment.py @@ -138,7 +138,9 @@ def _load_cbsa_crosswalk() -> Dict[str, str]: Returns: Dict mapping 5-digit county FIPS to CBSA code (or None if not in CBSA) """ - url = "https://data.nber.org/cbsa-csa-fips-county-crosswalk/2023/cbsa2fipsxw_2023.csv" + url = ( + "https://data.nber.org/cbsa-csa-fips-county-crosswalk/2023/cbsa2fipsxw_2023.csv" + ) try: df = pd.read_csv(url, dtype=str) # Build 5-digit county FIPS from state + county codes @@ -270,14 +272,10 @@ def get_all_geography_from_block(block_geoid: str) -> Dict[str, Optional[str]]: result = { "sldu": row["sldu"] if pd.notna(row["sldu"]) else None, "sldl": row["sldl"] if pd.notna(row["sldl"]) else None, - "place_fips": ( - row["place_fips"] if pd.notna(row["place_fips"]) else None - ), + "place_fips": (row["place_fips"] if pd.notna(row["place_fips"]) else None), "vtd": row["vtd"] if pd.notna(row["vtd"]) else None, "puma": row["puma"] if pd.notna(row["puma"]) else None, - "zcta": ( - row["zcta"] if has_zcta and pd.notna(row["zcta"]) else None - ), + "zcta": (row["zcta"] if has_zcta and pd.notna(row["zcta"]) else None), } return result return { @@ -446,17 +444,11 @@ def assign_geography_for_cd( - county_index: int32 indices into County enum (for backwards compat) """ # Assign blocks first - block_geoids = assign_blocks_for_cd( - cd_geoid, n_households, seed, distributions - ) + block_geoids = assign_blocks_for_cd(cd_geoid, n_households, seed, distributions) # Derive geography directly from block GEOID structure - county_fips = np.array( - [get_county_fips_from_block(b) for b in block_geoids] - ) - tract_geoids = np.array( - [get_tract_geoid_from_block(b) for b in block_geoids] - ) + county_fips = np.array([get_county_fips_from_block(b) for b in block_geoids]) + tract_geoids = np.array([get_tract_geoid_from_block(b) for b in block_geoids]) state_fips = np.array([get_state_fips_from_block(b) for b in block_geoids]) # CBSA lookup via county (may be None for rural areas) @@ -533,12 +525,8 @@ def derive_geography_from_blocks( Returns: Dict with same keys as assign_geography_for_cd. """ - county_fips = np.array( - [get_county_fips_from_block(b) for b in block_geoids] - ) - tract_geoids = np.array( - [get_tract_geoid_from_block(b) for b in block_geoids] - ) + county_fips = np.array([get_county_fips_from_block(b) for b in block_geoids]) + tract_geoids = np.array([get_tract_geoid_from_block(b) for b in block_geoids]) state_fips = np.array([get_state_fips_from_block(b) for b in block_geoids]) cbsa_codes = np.array([get_cbsa_from_county(c) or "" for c in county_fips]) county_indices = np.array( diff --git a/policyengine_us_data/calibration/calibration_utils.py b/policyengine_us_data/calibration/calibration_utils.py index 5cf9f1bc..9d10ee6a 100644 --- a/policyengine_us_data/calibration/calibration_utils.py +++ b/policyengine_us_data/calibration/calibration_utils.py @@ -352,9 +352,7 @@ def create_target_groups( for domain_var, var_name in pairs: var_mask = ( - (targets_df["variable"] == var_name) - & level_mask - & ~processed_mask + (targets_df["variable"] == var_name) & level_mask & ~processed_mask ) if has_domain and domain_var is not None: var_mask &= targets_df["domain_variable"] == domain_var @@ -380,15 +378,11 @@ def create_target_groups( # Format output based on level and count if n_targets == 1: value = matching["value"].iloc[0] - info_str = ( - f"{level_name} {label} (1 target, value={value:,.0f})" - ) + info_str = f"{level_name} {label} (1 target, value={value:,.0f})" print_str = f" Group {group_id}: {label} = {value:,.0f}" else: info_str = f"{level_name} {label} ({n_targets} targets)" - print_str = ( - f" Group {group_id}: {label} ({n_targets} targets)" - ) + print_str = f" Group {group_id}: {label} ({n_targets} targets)" group_info.append(f"Group {group_id}: {info_str}") print(print_str) @@ -628,9 +622,7 @@ def calculate_spm_thresholds_vectorized( for i in range(n_units): tenure_str = TENURE_CODE_MAP.get(int(tenure_codes[i]), "renter") base = base_thresholds[tenure_str] - equiv_scale = spm_equivalence_scale( - int(num_adults[i]), int(num_children[i]) - ) + equiv_scale = spm_equivalence_scale(int(num_adults[i]), int(num_children[i])) thresholds[i] = base * equiv_scale * spm_unit_geoadj[i] return thresholds diff --git a/policyengine_us_data/calibration/clone_and_assign.py b/policyengine_us_data/calibration/clone_and_assign.py index bc85dfd8..a140f1b1 100644 --- a/policyengine_us_data/calibration/clone_and_assign.py +++ b/policyengine_us_data/calibration/clone_and_assign.py @@ -110,9 +110,7 @@ def assign_random_geography( n_bad = collisions.sum() if n_bad == 0: break - clone_indices[collisions] = rng.choice( - len(blocks), size=n_bad, p=probs - ) + clone_indices[collisions] = rng.choice(len(blocks), size=n_bad, p=probs) clone_cds = cds[clone_indices] collisions = np.zeros(n_records, dtype=bool) for prev in range(clone_idx): diff --git a/policyengine_us_data/calibration/county_assignment.py b/policyengine_us_data/calibration/county_assignment.py index 6d32d30b..a1f262d7 100644 --- a/policyengine_us_data/calibration/county_assignment.py +++ b/policyengine_us_data/calibration/county_assignment.py @@ -150,9 +150,7 @@ def get_county_filter_probability( else: dist = _generate_uniform_distribution(cd_key) - return sum( - prob for county, prob in dist.items() if county in county_filter - ) + return sum(prob for county, prob in dist.items() if county in county_filter) def get_filtered_county_distribution( diff --git a/policyengine_us_data/calibration/create_source_imputed_cps.py b/policyengine_us_data/calibration/create_source_imputed_cps.py index 4381f72d..68dd876a 100644 --- a/policyengine_us_data/calibration/create_source_imputed_cps.py +++ b/policyengine_us_data/calibration/create_source_imputed_cps.py @@ -19,9 +19,7 @@ logger = logging.getLogger(__name__) INPUT_PATH = str(STORAGE_FOLDER / "stratified_extended_cps_2024.h5") -OUTPUT_PATH = str( - STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5" -) +OUTPUT_PATH = str(STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5") def create_source_imputed_cps( @@ -49,9 +47,7 @@ def create_source_imputed_cps( logger.info("Loaded %d households, time_period=%d", n_records, time_period) - geography = assign_random_geography( - n_records=n_records, n_clones=1, seed=seed - ) + geography = assign_random_geography(n_records=n_records, n_clones=1, seed=seed) base_states = geography.state_fips[:n_records] raw_data = sim.dataset.load_dataset() @@ -59,9 +55,7 @@ def create_source_imputed_cps( for var in raw_data: val = raw_data[var] if isinstance(val, dict): - data_dict[var] = { - int(k) if k.isdigit() else k: v for k, v in val.items() - } + data_dict[var] = {int(k) if k.isdigit() else k: v for k, v in val.items()} else: data_dict[var] = {time_period: val[...]} diff --git a/policyengine_us_data/calibration/create_stratified_cps.py b/policyengine_us_data/calibration/create_stratified_cps.py index e2632366..2aa15a9f 100644 --- a/policyengine_us_data/calibration/create_stratified_cps.py +++ b/policyengine_us_data/calibration/create_stratified_cps.py @@ -79,9 +79,7 @@ def create_stratified_cps_dataset( f" Top {100 - high_income_percentile}% (AGI >= ${high_income_threshold:,.0f}): {n_top:,}" ) print(f" Middle 25-{high_income_percentile}%: {n_middle:,}") - print( - f" Bottom 25% (AGI < ${bottom_25_pct_threshold:,.0f}): {n_bottom_25:,}" - ) + print(f" Bottom 25% (AGI < ${bottom_25_pct_threshold:,.0f}): {n_bottom_25:,}") # Calculate sampling rates # Keep ALL top earners, distribute remaining quota between middle and bottom @@ -132,9 +130,7 @@ def create_stratified_cps_dataset( # Top earners - keep all top_mask = agi >= high_income_threshold selected_mask[top_mask] = True - print( - f" Top {100 - high_income_percentile}%: selected {np.sum(top_mask):,}" - ) + print(f" Top {100 - high_income_percentile}%: selected {np.sum(top_mask):,}") # Bottom 25% bottom_mask = agi < bottom_25_pct_threshold @@ -271,10 +267,7 @@ def create_stratified_cps_dataset( if "person_id" in f and str(time_period) in f["person_id"]: person_ids = f["person_id"][str(time_period)][:] print(f" Final persons: {len(person_ids):,}") - if ( - "household_weight" in f - and str(time_period) in f["household_weight"] - ): + if "household_weight" in f and str(time_period) in f["household_weight"]: weights = f["household_weight"][str(time_period)][:] print(f" Final household weights sum: {np.sum(weights):,.0f}") @@ -342,7 +335,5 @@ def create_stratified_cps_dataset( ) print("\nExamples:") print(" python create_stratified_cps.py 30000") - print( - " python create_stratified_cps.py 50000 --top=99.5 --oversample-poor" - ) + print(" python create_stratified_cps.py 50000 --top=99.5 --oversample-poor") print(" python create_stratified_cps.py 30000 --seed=123 # reproducible") diff --git a/policyengine_us_data/calibration/publish_local_area.py b/policyengine_us_data/calibration/publish_local_area.py index 8e505351..ad3ccb1a 100644 --- a/policyengine_us_data/calibration/publish_local_area.py +++ b/policyengine_us_data/calibration/publish_local_area.py @@ -168,17 +168,14 @@ def build_h5( # CD subset filtering: zero out cells whose CD isn't in subset if cd_subset is not None: cd_subset_set = set(cd_subset) - cd_mask = np.vectorize(lambda cd: cd in cd_subset_set)( - clone_cds_matrix - ) + cd_mask = np.vectorize(lambda cd: cd in cd_subset_set)(clone_cds_matrix) W[~cd_mask] = 0 # County filtering: scale weights by P(target_counties | CD) if county_filter is not None: unique_cds = np.unique(clone_cds_matrix) cd_prob = { - cd: get_county_filter_probability(cd, county_filter) - for cd in unique_cds + cd: get_county_filter_probability(cd, county_filter) for cd in unique_cds } p_matrix = np.vectorize( cd_prob.__getitem__, @@ -205,15 +202,11 @@ def build_h5( ) clone_weights = W[active_geo, active_hh] active_blocks = blocks.reshape(n_clones_total, n_hh)[active_geo, active_hh] - active_clone_cds = clone_cds.reshape(n_clones_total, n_hh)[ - active_geo, active_hh - ] + active_clone_cds = clone_cds.reshape(n_clones_total, n_hh)[active_geo, active_hh] empty_count = np.sum(active_blocks == "") if empty_count > 0: - raise ValueError( - f"{empty_count} active clones have empty block GEOIDs" - ) + raise ValueError(f"{empty_count} active clones have empty block GEOIDs") print(f"Active clones: {n_clones:,}") print(f"Total weight: {clone_weights.sum():,.0f}") @@ -258,16 +251,12 @@ def build_h5( # === Build clone index arrays === hh_clone_idx = active_hh - persons_per_clone = np.array( - [len(hh_to_persons.get(h, [])) for h in active_hh] - ) + persons_per_clone = np.array([len(hh_to_persons.get(h, [])) for h in active_hh]) person_parts = [ np.array(hh_to_persons.get(h, []), dtype=np.int64) for h in active_hh ] person_clone_idx = ( - np.concatenate(person_parts) - if person_parts - else np.array([], dtype=np.int64) + np.concatenate(person_parts) if person_parts else np.array([], dtype=np.int64) ) entity_clone_idx = {} @@ -276,8 +265,7 @@ def build_h5( epc = np.array([len(hh_to_entity[ek].get(h, [])) for h in active_hh]) entities_per_clone[ek] = epc parts = [ - np.array(hh_to_entity[ek].get(h, []), dtype=np.int64) - for h in active_hh + np.array(hh_to_entity[ek].get(h, []), dtype=np.int64) for h in active_hh ] entity_clone_idx[ek] = ( np.concatenate(parts) if parts else np.array([], dtype=np.int64) @@ -316,9 +304,7 @@ def build_h5( sorted_keys = entity_keys[sorted_order] sorted_new = new_entity_ids[ek][sorted_order] - p_old_eids = person_entity_id_arrays[ek][person_clone_idx].astype( - np.int64 - ) + p_old_eids = person_entity_id_arrays[ek][person_clone_idx].astype(np.int64) person_keys = clone_ids_for_persons * offset + p_old_eids positions = np.searchsorted(sorted_keys, person_keys) @@ -473,9 +459,7 @@ def build_h5( } # === Gap 4: Congressional district GEOID === - clone_cd_geoids = np.array( - [int(cd) for cd in active_clone_cds], dtype=np.int32 - ) + clone_cd_geoids = np.array([int(cd) for cd in active_clone_cds], dtype=np.int32) data["congressional_district_geoid"] = { time_period: clone_cd_geoids, } @@ -495,9 +479,7 @@ def build_h5( ) # Get cloned person ages and SPM unit IDs - person_ages = sim.calculate("age", map_to="person").values[ - person_clone_idx - ] + person_ages = sim.calculate("age", map_to="person").values[person_clone_idx] # Get cloned tenure types spm_tenure_holder = sim.get_holder("spm_unit_tenure_type") @@ -589,9 +571,7 @@ def build_h5( print(f"Total population (person weights): {pw.sum():,.0f}") # === HDFStore output (entity-level format) === - entity_dfs = split_data_into_entity_dfs( - data, sim.tax_benefit_system, time_period - ) + entity_dfs = split_data_into_entity_dfs(data, sim.tax_benefit_system, time_period) manifest_df = build_uprating_manifest(data, sim.tax_benefit_system) save_hdfstore(entity_dfs, manifest_df, str(output_path), time_period) @@ -659,9 +639,7 @@ def build_states( if upload: print(f"Uploading {state_code}.h5 to GCP...") - upload_local_area_file( - str(output_path), "states", skip_hf=True - ) + upload_local_area_file(str(output_path), "states", skip_hf=True) # Upload HDFStore file if it exists hdfstore_path = str(output_path).replace(".h5", ".hdfstore.h5") @@ -679,9 +657,7 @@ def build_states( print(f"Completed {state_code}") if upload and len(hf_queue) >= hf_batch_size: - print( - f"\nUploading batch of {len(hf_queue)} files to HuggingFace..." - ) + print(f"\nUploading batch of {len(hf_queue)} files to HuggingFace...") upload_local_area_batch_to_hf(hf_queue) hf_queue = [] @@ -690,9 +666,7 @@ def build_states( raise if upload and hf_queue: - print( - f"\nUploading final batch of {len(hf_queue)} files to HuggingFace..." - ) + print(f"\nUploading final batch of {len(hf_queue)} files to HuggingFace...") upload_local_area_batch_to_hf(hf_queue) @@ -744,9 +718,7 @@ def build_districts( if upload: print(f"Uploading {friendly_name}.h5 to GCP...") - upload_local_area_file( - str(output_path), "districts", skip_hf=True - ) + upload_local_area_file(str(output_path), "districts", skip_hf=True) # Upload HDFStore file if it exists hdfstore_path = str(output_path).replace(".h5", ".hdfstore.h5") @@ -764,9 +736,7 @@ def build_districts( print(f"Completed {friendly_name}") if upload and len(hf_queue) >= hf_batch_size: - print( - f"\nUploading batch of {len(hf_queue)} files to HuggingFace..." - ) + print(f"\nUploading batch of {len(hf_queue)} files to HuggingFace...") upload_local_area_batch_to_hf(hf_queue) hf_queue = [] @@ -775,9 +745,7 @@ def build_districts( raise if upload and hf_queue: - print( - f"\nUploading final batch of {len(hf_queue)} files to HuggingFace..." - ) + print(f"\nUploading final batch of {len(hf_queue)} files to HuggingFace...") upload_local_area_batch_to_hf(hf_queue) @@ -824,14 +792,10 @@ def build_cities( if upload: print("Uploading NYC.h5 to GCP...") - upload_local_area_file( - str(output_path), "cities", skip_hf=True - ) + upload_local_area_file(str(output_path), "cities", skip_hf=True) # Upload HDFStore file if it exists - hdfstore_path = str(output_path).replace( - ".h5", ".hdfstore.h5" - ) + hdfstore_path = str(output_path).replace(".h5", ".hdfstore.h5") if os.path.exists(hdfstore_path): print("Uploading NYC.hdfstore.h5 to GCP...") upload_local_area_file( @@ -850,9 +814,7 @@ def build_cities( raise if upload and hf_queue: - print( - f"\nUploading batch of {len(hf_queue)} city files to HuggingFace..." - ) + print(f"\nUploading batch of {len(hf_queue)} city files to HuggingFace...") upload_local_area_batch_to_hf(hf_queue) @@ -929,9 +891,7 @@ def main(): elif args.skip_download: inputs = { "weights": WORK_DIR / "calibration_weights.npy", - "dataset": ( - WORK_DIR / "source_imputed_stratified_extended_cps.h5" - ), + "dataset": (WORK_DIR / "source_imputed_stratified_extended_cps.h5"), } print("Using existing files in work directory:") for key, path in inputs.items(): diff --git a/policyengine_us_data/calibration/puf_impute.py b/policyengine_us_data/calibration/puf_impute.py index dfdada5f..445bd758 100644 --- a/policyengine_us_data/calibration/puf_impute.py +++ b/policyengine_us_data/calibration/puf_impute.py @@ -194,9 +194,7 @@ "social_security", ] -RETIREMENT_PREDICTORS = ( - RETIREMENT_DEMOGRAPHIC_PREDICTORS + RETIREMENT_INCOME_PREDICTORS -) +RETIREMENT_PREDICTORS = RETIREMENT_DEMOGRAPHIC_PREDICTORS + RETIREMENT_INCOME_PREDICTORS def _get_retirement_limits(year: int) -> dict: @@ -411,9 +409,7 @@ def reconcile_ss_subcomponents( if puf_has_ss.any(): shares = _qrf_ss_shares(data, n_cps, time_period, puf_has_ss) if shares is None: - shares = _age_heuristic_ss_shares( - data, n_cps, time_period, puf_has_ss - ) + shares = _age_heuristic_ss_shares(data, n_cps, time_period, puf_has_ss) for sub in SS_SUBCOMPONENTS: if sub not in data: @@ -492,17 +488,13 @@ def _map_to_entity(pred_values, variable_name): return pred_values entity = var_meta.entity.key if entity != "person": - return cps_sim.populations[entity].value_from_first_person( - pred_values - ) + return cps_sim.populations[entity].value_from_first_person(pred_values) return pred_values # Impute weeks_unemployed for PUF half puf_weeks = None if y_full is not None and dataset_path is not None: - puf_weeks = _impute_weeks_unemployed( - data, y_full, time_period, dataset_path - ) + puf_weeks = _impute_weeks_unemployed(data, y_full, time_period, dataset_path) # Impute retirement contributions for PUF half puf_retirement = None @@ -526,24 +518,14 @@ def _map_to_entity(pred_values, variable_name): time_period: np.concatenate([values, values + values.max()]) } elif "_weight" in variable: - new_data[variable] = { - time_period: np.concatenate([values, values * 0]) - } + new_data[variable] = {time_period: np.concatenate([values, values * 0])} elif variable == "weeks_unemployed" and puf_weeks is not None: - new_data[variable] = { - time_period: np.concatenate([values, puf_weeks]) - } - elif ( - variable in CPS_RETIREMENT_VARIABLES and puf_retirement is not None - ): + new_data[variable] = {time_period: np.concatenate([values, puf_weeks])} + elif variable in CPS_RETIREMENT_VARIABLES and puf_retirement is not None: puf_vals = puf_retirement[variable] - new_data[variable] = { - time_period: np.concatenate([values, puf_vals]) - } + new_data[variable] = {time_period: np.concatenate([values, puf_vals])} else: - new_data[variable] = { - time_period: np.concatenate([values, values]) - } + new_data[variable] = {time_period: np.concatenate([values, values])} new_data["state_fips"] = { time_period: np.concatenate([state_fips, state_fips]).astype(np.int32) @@ -656,11 +638,7 @@ def _impute_weeks_unemployed( logger.info( "Imputed weeks_unemployed for PUF: %d with weeks > 0, mean = %.1f", (imputed_weeks > 0).sum(), - ( - imputed_weeks[imputed_weeks > 0].mean() - if (imputed_weeks > 0).any() - else 0 - ), + (imputed_weeks[imputed_weeks > 0].mean() if (imputed_weeks > 0).any() else 0), ) return imputed_weeks @@ -822,9 +800,7 @@ def _run_qrf_imputation( puf_sim = Microsimulation(dataset=puf_dataset) - puf_agi = puf_sim.calculate( - "adjusted_gross_income", map_to="person" - ).values + puf_agi = puf_sim.calculate("adjusted_gross_income", map_to="person").values X_train_full = puf_sim.calculate_dataframe( DEMOGRAPHIC_PREDICTORS + IMPUTED_VARIABLES @@ -901,9 +877,7 @@ def _stratified_subsample_index( if remaining_quota >= len(bottom_idx): selected_bottom = bottom_idx else: - selected_bottom = rng.choice( - bottom_idx, size=remaining_quota, replace=False - ) + selected_bottom = rng.choice(bottom_idx, size=remaining_quota, replace=False) selected = np.concatenate([top_idx, selected_bottom]) selected.sort() diff --git a/policyengine_us_data/calibration/sanity_checks.py b/policyengine_us_data/calibration/sanity_checks.py index 0ea59218..e1f59064 100644 --- a/policyengine_us_data/calibration/sanity_checks.py +++ b/policyengine_us_data/calibration/sanity_checks.py @@ -214,9 +214,7 @@ def _get(f, path): { "check": "per_hh_employment_income", "status": "WARN", - "detail": ( - f"${per_hh:,.0f}/hh (expected $10K-$200K)" - ), + "detail": (f"${per_hh:,.0f}/hh (expected $10K-$200K)"), } ) else: diff --git a/policyengine_us_data/calibration/source_impute.py b/policyengine_us_data/calibration/source_impute.py index 339e038e..25c7975a 100644 --- a/policyengine_us_data/calibration/source_impute.py +++ b/policyengine_us_data/calibration/source_impute.py @@ -225,9 +225,7 @@ def _person_state_fips( if hh_ids_person is not None: hh_ids = data["household_id"][time_period] hh_to_idx = {int(hh_id): i for i, hh_id in enumerate(hh_ids)} - return np.array( - [state_fips[hh_to_idx[int(hh_id)]] for hh_id in hh_ids_person] - ) + return np.array([state_fips[hh_to_idx[int(hh_id)]] for hh_id in hh_ids_person]) # Fallback: distribute persons across households as evenly # as possible (first households get any remainder). n_hh = len(data["household_id"][time_period]) @@ -264,9 +262,9 @@ def _impute_acs( predictors = ACS_PREDICTORS + ["state_fips"] acs_df = acs.calculate_dataframe(ACS_PREDICTORS + ACS_IMPUTED_VARIABLES) - acs_df["state_fips"] = acs.calculate( - "state_fips", map_to="person" - ).values.astype(np.float32) + acs_df["state_fips"] = acs.calculate("state_fips", map_to="person").values.astype( + np.float32 + ) train_df = acs_df[acs_df.is_household_head].sample(10_000, random_state=42) train_df = _encode_tenure_type(train_df) @@ -368,16 +366,10 @@ def _impute_sipp( sipp_df["is_under_18"] = sipp_df.TAGE < 18 sipp_df["is_under_6"] = sipp_df.TAGE < 6 sipp_df["count_under_18"] = ( - sipp_df.groupby("SSUID")["is_under_18"] - .sum() - .loc[sipp_df.SSUID.values] - .values + sipp_df.groupby("SSUID")["is_under_18"].sum().loc[sipp_df.SSUID.values].values ) sipp_df["count_under_6"] = ( - sipp_df.groupby("SSUID")["is_under_6"] - .sum() - .loc[sipp_df.SSUID.values] - .values + sipp_df.groupby("SSUID")["is_under_6"].sum().loc[sipp_df.SSUID.values].values ) tip_cols = [ @@ -408,9 +400,9 @@ def _impute_sipp( age_df = pd.DataFrame({"hh": hh_ids_person, "age": person_ages}) under_18 = age_df.groupby("hh")["age"].apply(lambda x: (x < 18).sum()) under_6 = age_df.groupby("hh")["age"].apply(lambda x: (x < 6).sum()) - cps_tip_df["count_under_18"] = under_18.loc[ - hh_ids_person - ].values.astype(np.float32) + cps_tip_df["count_under_18"] = under_18.loc[hh_ids_person].values.astype( + np.float32 + ) cps_tip_df["count_under_6"] = under_6.loc[hh_ids_person].values.astype( np.float32 ) @@ -499,10 +491,7 @@ def _impute_sipp( asset_train.index, size=min(20_000, len(asset_train)), replace=True, - p=( - asset_train.household_weight - / asset_train.household_weight.sum() - ), + p=(asset_train.household_weight / asset_train.household_weight.sum()), ) ] @@ -513,15 +502,15 @@ def _impute_sipp( ["employment_income", "age", "is_male"], ) if "is_male" in cps_asset_df.columns: - cps_asset_df["is_female"] = ( - ~cps_asset_df["is_male"].astype(bool) - ).astype(np.float32) + cps_asset_df["is_female"] = (~cps_asset_df["is_male"].astype(bool)).astype( + np.float32 + ) else: cps_asset_df["is_female"] = 0.0 if "is_married" in data: - cps_asset_df["is_married"] = data["is_married"][ - time_period - ].astype(np.float32) + cps_asset_df["is_married"] = data["is_married"][time_period].astype( + np.float32 + ) else: cps_asset_df["is_married"] = 0.0 cps_asset_df["count_under_18"] = ( @@ -623,9 +612,7 @@ def _impute_scf( cps_df = _build_cps_receiver(data, time_period, dataset_path, pe_vars) if "is_male" in cps_df.columns: - cps_df["is_female"] = (~cps_df["is_male"].astype(bool)).astype( - np.float32 - ) + cps_df["is_female"] = (~cps_df["is_male"].astype(bool)).astype(np.float32) else: cps_df["is_female"] = 0.0 diff --git a/policyengine_us_data/calibration/stacked_dataset_builder.py b/policyengine_us_data/calibration/stacked_dataset_builder.py index 172f05fb..0089f0d1 100644 --- a/policyengine_us_data/calibration/stacked_dataset_builder.py +++ b/policyengine_us_data/calibration/stacked_dataset_builder.py @@ -105,9 +105,7 @@ f"{geography.n_records} records" ) - print( - f"Geography: {geography.n_clones} clones x {geography.n_records} records" - ) + print(f"Geography: {geography.n_clones} clones x {geography.n_records} records") takeup_filter = [spec["variable"] for spec in SIMPLE_TAKEUP_VARS] diff --git a/policyengine_us_data/calibration/unified_calibration.py b/policyengine_us_data/calibration/unified_calibration.py index 361f0dba..66bc1f9b 100644 --- a/policyengine_us_data/calibration/unified_calibration.py +++ b/policyengine_us_data/calibration/unified_calibration.py @@ -136,9 +136,7 @@ def check_package_staleness(metadata: dict) -> None: built_dt = datetime.datetime.fromisoformat(created) age = datetime.datetime.now() - built_dt if age.days > 7: - print( - f"WARNING: Package is {age.days} days old (built {created})" - ) + print(f"WARNING: Package is {age.days} days old (built {created})") except Exception: pass @@ -171,9 +169,7 @@ def check_package_staleness(metadata: dict) -> None: def parse_args(argv=None): - parser = argparse.ArgumentParser( - description="Unified L0 calibration pipeline" - ) + parser = argparse.ArgumentParser(description="Unified L0 calibration pipeline") parser.add_argument( "--dataset", default=None, @@ -342,9 +338,7 @@ def _match_rules(targets_df, rules): for rule in rules: rule_mask = targets_df["variable"] == rule["variable"] if "geo_level" in rule: - rule_mask = rule_mask & ( - targets_df["geo_level"] == rule["geo_level"] - ) + rule_mask = rule_mask & (targets_df["geo_level"] == rule["geo_level"]) if "domain_variable" in rule: rule_mask = rule_mask & ( targets_df["domain_variable"] == rule["domain_variable"] @@ -584,9 +578,7 @@ def fit_l0_weights( import torch - os.environ.setdefault( - "PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True" - ) + os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") n_total = X_sparse.shape[1] if initial_weights is None: @@ -629,9 +621,7 @@ def _flushed_print(*args, **kwargs): builtins.print = _flushed_print enable_logging = ( - log_freq is not None - and log_path is not None - and target_names is not None + log_freq is not None and log_path is not None and target_names is not None ) if enable_logging: Path(log_path).parent.mkdir(parents=True, exist_ok=True) @@ -668,9 +658,7 @@ def _flushed_print(*args, **kwargs): with torch.no_grad(): y_pred = model.predict(X_sparse).cpu().numpy() - weights_snap = ( - model.get_weights(deterministic=True).cpu().numpy() - ) + weights_snap = model.get_weights(deterministic=True).cpu().numpy() active_w = weights_snap[weights_snap > 0] nz = len(active_w) @@ -714,9 +702,7 @@ def _flushed_print(*args, **kwargs): flush=True, ) - ach_flags = ( - achievable if achievable is not None else [True] * len(targets) - ) + ach_flags = achievable if achievable is not None else [True] * len(targets) with open(log_path, "a") as f: for i in range(len(targets)): est = y_pred[i] @@ -987,8 +973,7 @@ def run_calibration( ) source_path = str( - Path(dataset_path).parent - / f"source_imputed_{Path(dataset_path).stem}.h5" + Path(dataset_path).parent / f"source_imputed_{Path(dataset_path).stem}.h5" ) with h5py.File(source_path, "w") as f: for var, time_dict in data_dict.items(): @@ -1189,9 +1174,7 @@ def main(argv=None): f"Dataset not found: {dataset_path}\n" "Run 'make data' first, or pass --dataset with a valid path." ) - db_path = args.db_path or str( - STORAGE_FOLDER / "calibration" / "policy_data.db" - ) + db_path = args.db_path or str(STORAGE_FOLDER / "calibration" / "policy_data.db") output_path = args.output or str( STORAGE_FOLDER / "calibration" / "calibration_weights.npy" ) @@ -1205,15 +1188,11 @@ def main(argv=None): domain_variables = None if args.domain_variables: - domain_variables = [ - x.strip() for x in args.domain_variables.split(",") - ] + domain_variables = [x.strip() for x in args.domain_variables.split(",")] hierarchical_domains = None if args.hierarchical_domains: - hierarchical_domains = [ - x.strip() for x in args.hierarchical_domains.split(",") - ] + hierarchical_domains = [x.strip() for x in args.hierarchical_domains.split(",")] t_start = time.time() diff --git a/policyengine_us_data/calibration/unified_matrix_builder.py b/policyengine_us_data/calibration/unified_matrix_builder.py index de80d015..04d785ff 100644 --- a/policyengine_us_data/calibration/unified_matrix_builder.py +++ b/policyengine_us_data/calibration/unified_matrix_builder.py @@ -124,9 +124,7 @@ def _compute_single_state( if rerandomize_takeup: for spec in SIMPLE_TAKEUP_VARS: entity = spec["entity"] - n_ent = len( - state_sim.calculate(f"{entity}_id", map_to=entity).values - ) + n_ent = len(state_sim.calculate(f"{entity}_id", map_to=entity).values) state_sim.set_input( spec["variable"], time_period, @@ -252,9 +250,7 @@ def _compute_single_state_group_counties( if rerandomize_takeup: for spec in SIMPLE_TAKEUP_VARS: entity = spec["entity"] - n_ent = len( - state_sim.calculate(f"{entity}_id", map_to=entity).values - ) + n_ent = len(state_sim.calculate(f"{entity}_id", map_to=entity).values) state_sim.set_input( spec["variable"], time_period, @@ -327,9 +323,7 @@ def _assemble_clone_values_standalone( state_masks = {int(s): clone_states == s for s in unique_clone_states} unique_person_states = np.unique(person_states) - person_state_masks = { - int(s): person_states == s for s in unique_person_states - } + person_state_masks = {int(s): person_states == s for s in unique_person_states} county_masks = {} unique_counties = None if clone_counties is not None and county_values: @@ -746,18 +740,10 @@ def _build_entity_relationship(self, sim) -> pd.DataFrame: self._entity_rel_cache = pd.DataFrame( { - "person_id": sim.calculate( - "person_id", map_to="person" - ).values, - "household_id": sim.calculate( - "household_id", map_to="person" - ).values, - "tax_unit_id": sim.calculate( - "tax_unit_id", map_to="person" - ).values, - "spm_unit_id": sim.calculate( - "spm_unit_id", map_to="person" - ).values, + "person_id": sim.calculate("person_id", map_to="person").values, + "household_id": sim.calculate("household_id", map_to="person").values, + "tax_unit_id": sim.calculate("tax_unit_id", map_to="person").values, + "spm_unit_id": sim.calculate("spm_unit_id", map_to="person").values, } ) return self._entity_rel_cache @@ -877,9 +863,7 @@ def _build_state_values( except Exception as exc: for f in futures: f.cancel() - raise RuntimeError( - f"State {st} failed: {exc}" - ) from exc + raise RuntimeError(f"State {st} failed: {exc}") from exc else: from policyengine_us import Microsimulation from policyengine_us_data.utils.takeup import ( @@ -935,9 +919,7 @@ def _build_state_values( for spec in SIMPLE_TAKEUP_VARS: entity = spec["entity"] n_ent = len( - state_sim.calculate( - f"{entity}_id", map_to=entity - ).values + state_sim.calculate(f"{entity}_id", map_to=entity).values ) state_sim.set_input( spec["variable"], @@ -1120,9 +1102,7 @@ def _build_county_values( except Exception as exc: for f in futures: f.cancel() - raise RuntimeError( - f"State group {sf} failed: {exc}" - ) from exc + raise RuntimeError(f"State group {sf} failed: {exc}") from exc else: from policyengine_us import Microsimulation from policyengine_us_data.utils.takeup import ( @@ -1298,9 +1278,7 @@ def _assemble_clone_values( # Pre-compute masks to avoid recomputing per variable state_masks = {int(s): clone_states == s for s in unique_clone_states} unique_person_states = np.unique(person_states) - person_state_masks = { - int(s): person_states == s for s in unique_person_states - } + person_state_masks = {int(s): person_states == s for s in unique_person_states} county_masks = {} unique_counties = None if clone_counties is not None and county_values: @@ -1313,9 +1291,7 @@ def _assemble_clone_values( continue if var in cdv and county_values and clone_counties is not None: first_county = unique_counties[0] - if var not in county_values.get(first_county, {}).get( - "hh", {} - ): + if var not in county_values.get(first_county, {}).get("hh", {}): continue arr = np.empty(n_records, dtype=np.float32) for county in unique_counties: @@ -1457,9 +1433,7 @@ def _calculate_uprating_factors(self, params) -> dict: factors[(from_year, "cpi")] = 1.0 try: - pop_from = params.calibration.gov.census.populations.total( - from_year - ) + pop_from = params.calibration.gov.census.populations.total(from_year) pop_to = params.calibration.gov.census.populations.total( self.time_period ) @@ -1536,9 +1510,7 @@ def _get_state_uprating_factors( var_factors[var] = 1.0 continue period = row.iloc[0]["period"] - factor, _ = self._get_uprating_info( - var, period, national_factors - ) + factor, _ = self._get_uprating_info(var, period, national_factors) var_factors[var] = factor result[state_int] = var_factors @@ -1673,9 +1645,7 @@ def _make_target_name( non_geo = [c for c in constraints if c["variable"] not in _GEO_VARS] if non_geo: - strs = [ - f"{c['variable']}{c['operation']}{c['value']}" for c in non_geo - ] + strs = [f"{c['variable']}{c['operation']}{c['value']}" for c in non_geo] parts.append("[" + ",".join(strs) + "]") return "/".join(parts) @@ -1819,15 +1789,9 @@ def build_matrix( n_targets = len(targets_df) # 2. Sort targets by geographic level - targets_df["_geo_level"] = targets_df["geographic_id"].apply( - get_geo_level - ) - targets_df = targets_df.sort_values( - ["_geo_level", "variable", "geographic_id"] - ) - targets_df = targets_df.drop(columns=["_geo_level"]).reset_index( - drop=True - ) + targets_df["_geo_level"] = targets_df["geographic_id"].apply(get_geo_level) + targets_df = targets_df.sort_values(["_geo_level", "variable", "geographic_id"]) + targets_df = targets_df.drop(columns=["_geo_level"]).reset_index(drop=True) # 3. Build column index structures from geography state_col_lists: Dict[int, list] = defaultdict(list) @@ -1854,9 +1818,7 @@ def build_matrix( geo_id = row["geographic_id"] target_geo_info.append((geo_level, geo_id)) - non_geo = [ - c for c in constraints if c["variable"] not in _GEO_VARS - ] + non_geo = [c for c in constraints if c["variable"] not in _GEO_VARS] non_geo_constraints_list.append(non_geo) target_names.append( @@ -1895,14 +1857,10 @@ def build_matrix( # 5c. State-independent structures (computed once) entity_rel = self._build_entity_relationship(sim) - household_ids = sim.calculate( - "household_id", map_to="household" - ).values + household_ids = sim.calculate("household_id", map_to="household").values person_hh_ids = sim.calculate("household_id", map_to="person").values hh_id_to_idx = {int(hid): idx for idx, hid in enumerate(household_ids)} - person_hh_indices = np.array( - [hh_id_to_idx[int(hid)] for hid in person_hh_ids] - ) + person_hh_indices = np.array([hh_id_to_idx[int(hid)] for hid in person_hh_ids]) tax_benefit_system = sim.tax_benefit_system # Pre-extract entity keys so workers don't need @@ -1910,9 +1868,7 @@ def build_matrix( variable_entity_map: Dict[str, str] = {} for var in unique_variables: if var.endswith("_count") and var in tax_benefit_system.variables: - variable_entity_map[var] = tax_benefit_system.variables[ - var - ].entity.key + variable_entity_map[var] = tax_benefit_system.variables[var].entity.key # 5c-extra: Entity-to-household index maps for takeup affected_target_info = {} @@ -1927,9 +1883,7 @@ def build_matrix( # Build entity-to-household index arrays spm_to_hh_id = ( - entity_rel.groupby("spm_unit_id")["household_id"] - .first() - .to_dict() + entity_rel.groupby("spm_unit_id")["household_id"].first().to_dict() ) spm_ids = sim.calculate("spm_unit_id", map_to="spm_unit").values spm_hh_idx = np.array( @@ -1937,9 +1891,7 @@ def build_matrix( ) tu_to_hh_id = ( - entity_rel.groupby("tax_unit_id")["household_id"] - .first() - .to_dict() + entity_rel.groupby("tax_unit_id")["household_id"].first().to_dict() ) tu_ids = sim.calculate("tax_unit_id", map_to="tax_unit").values tu_hh_idx = np.array( @@ -1958,9 +1910,7 @@ def build_matrix( f"{entity_level}_id", map_to=entity_level, ).values - ent_id_to_idx = { - int(eid): idx for idx, eid in enumerate(ent_ids) - } + ent_id_to_idx = {int(eid): idx for idx, eid in enumerate(ent_ids)} person_ent_ids = entity_rel[f"{entity_level}_id"].values entity_to_person_idx[entity_level] = np.array( [ent_id_to_idx[int(eid)] for eid in person_ent_ids] @@ -1983,9 +1933,7 @@ def build_matrix( for tvar, info in affected_target_info.items(): rk = info["rate_key"] if rk not in precomputed_rates: - precomputed_rates[rk] = load_take_up_rate( - rk, self.time_period - ) + precomputed_rates[rk] = load_take_up_rate(rk, self.time_period) # Store for post-optimization stacked takeup self.entity_hh_idx_map = entity_hh_idx_map @@ -2086,9 +2034,7 @@ def build_matrix( except Exception as exc: for f in futures: f.cancel() - raise RuntimeError( - f"Clone {ci} failed: {exc}" - ) from exc + raise RuntimeError(f"Clone {ci} failed: {exc}") from exc else: # ---- Sequential clone processing (unchanged) ---- @@ -2155,9 +2101,7 @@ def build_matrix( ent_counties = clone_counties[ent_hh] for cfips in np.unique(ent_counties): m = ent_counties == cfips - cv = county_values.get(cfips, {}).get( - "entity", {} - ) + cv = county_values.get(cfips, {}).get("entity", {}) if tvar in cv: ent_eligible[m] = cv[tvar][m] else: @@ -2182,9 +2126,7 @@ def build_matrix( ent_hh_ids, ) - ent_values = (ent_eligible * ent_takeup).astype( - np.float32 - ) + ent_values = (ent_eligible * ent_takeup).astype(np.float32) hh_result = np.zeros(n_records, dtype=np.float32) np.add.at(hh_result, ent_hh, ent_values) @@ -2244,17 +2186,15 @@ def build_matrix( constraint_key, ) if vkey not in count_cache: - count_cache[vkey] = ( - _calculate_target_values_standalone( - target_variable=variable, - non_geo_constraints=non_geo, - n_households=n_records, - hh_vars=hh_vars, - person_vars=person_vars, - entity_rel=entity_rel, - household_ids=household_ids, - variable_entity_map=variable_entity_map, - ) + count_cache[vkey] = _calculate_target_values_standalone( + target_variable=variable, + non_geo_constraints=non_geo, + n_households=n_records, + hh_vars=hh_vars, + person_vars=person_vars, + entity_rel=entity_rel, + household_ids=household_ids, + variable_entity_map=variable_entity_map, ) values = count_cache[vkey] else: diff --git a/policyengine_us_data/calibration/validate_national_h5.py b/policyengine_us_data/calibration/validate_national_h5.py index ba303812..cbe22796 100644 --- a/policyengine_us_data/calibration/validate_national_h5.py +++ b/policyengine_us_data/calibration/validate_national_h5.py @@ -145,7 +145,9 @@ def main(argv=None): icon = ( "PASS" if r["status"] == "PASS" - else "FAIL" if r["status"] == "FAIL" else "WARN" + else "FAIL" + if r["status"] == "FAIL" + else "WARN" ) print(f" [{icon}] {r['check']}: {r['detail']}") diff --git a/policyengine_us_data/calibration/validate_package.py b/policyengine_us_data/calibration/validate_package.py index 4321fbf8..c8ed16bc 100644 --- a/policyengine_us_data/calibration/validate_package.py +++ b/policyengine_us_data/calibration/validate_package.py @@ -85,9 +85,7 @@ def validate_package( ) k = min(n_hardest, len(ratios)) hardest_local_idx = np.argpartition(ratios, k)[:k] - hardest_local_idx = hardest_local_idx[ - np.argsort(ratios[hardest_local_idx]) - ] + hardest_local_idx = hardest_local_idx[np.argsort(ratios[hardest_local_idx])] hardest_global_idx = achievable_idx[hardest_local_idx] hardest_targets = pd.DataFrame( @@ -96,9 +94,7 @@ def validate_package( "domain_variable": targets_df["domain_variable"] .iloc[hardest_global_idx] .values, - "variable": targets_df["variable"] - .iloc[hardest_global_idx] - .values, + "variable": targets_df["variable"].iloc[hardest_global_idx].values, "geographic_id": targets_df["geographic_id"] .iloc[hardest_global_idx] .values, @@ -190,9 +186,7 @@ def format_report(result: ValidationResult, package_path: str = None) -> str: lines.append(", ".join(parts)) lines.append("") - pct = ( - 100 * result.n_achievable / result.n_targets if result.n_targets else 0 - ) + pct = 100 * result.n_achievable / result.n_targets if result.n_targets else 0 pct_imp = 100 - pct lines.append("--- Achievability ---") lines.append( @@ -206,9 +200,7 @@ def format_report(result: ValidationResult, package_path: str = None) -> str: if len(result.impossible_targets) > 0: lines.append("--- Impossible Targets ---") for _, row in result.impossible_targets.iterrows(): - lines.append( - f" {row['target_name']:<60s} {row['target_value']:>14,.0f}" - ) + lines.append(f" {row['target_name']:<60s} {row['target_value']:>14,.0f}") lines.append("") if len(result.impossible_by_group) > 1: @@ -265,9 +257,7 @@ def format_report(result: ValidationResult, package_path: str = None) -> str: f" targets below ratio {result.strict_ratio})" ) elif result.n_impossible > 0: - lines.append( - f"RESULT: FAIL ({result.n_impossible} impossible targets)" - ) + lines.append(f"RESULT: FAIL ({result.n_impossible} impossible targets)") else: lines.append("RESULT: PASS") @@ -275,9 +265,7 @@ def format_report(result: ValidationResult, package_path: str = None) -> str: def main(): - parser = argparse.ArgumentParser( - description="Validate a calibration package" - ) + parser = argparse.ArgumentParser(description="Validate a calibration package") parser.add_argument( "path", nargs="?", diff --git a/policyengine_us_data/calibration/validate_staging.py b/policyengine_us_data/calibration/validate_staging.py index 4ecea143..be2f908d 100644 --- a/policyengine_us_data/calibration/validate_staging.py +++ b/policyengine_us_data/calibration/validate_staging.py @@ -178,9 +178,7 @@ def _batch_stratum_constraints(engine, stratum_ids) -> dict: df = pd.read_sql(query, conn) result = {} for sid, group in df.groupby("stratum_id"): - result[int(sid)] = group[["variable", "operation", "value"]].to_dict( - "records" - ) + result[int(sid)] = group[["variable", "operation", "value"]].to_dict("records") for sid in stratum_ids: result.setdefault(int(sid), []) return result @@ -264,15 +262,9 @@ def _build_entity_rel(sim) -> pd.DataFrame: return pd.DataFrame( { "person_id": sim.calculate("person_id", map_to="person").values, - "household_id": sim.calculate( - "household_id", map_to="person" - ).values, - "tax_unit_id": sim.calculate( - "tax_unit_id", map_to="person" - ).values, - "spm_unit_id": sim.calculate( - "spm_unit_id", map_to="person" - ).values, + "household_id": sim.calculate("household_id", map_to="person").values, + "tax_unit_id": sim.calculate("tax_unit_id", map_to="person").values, + "spm_unit_id": sim.calculate("spm_unit_id", map_to="person").values, } ) @@ -712,9 +704,7 @@ def _run_state_via_districts( variable = row_data["variable"] stratum_id = int(row_data["stratum_id"]) constraints = constraints_map.get(stratum_id, []) - target_name = UnifiedMatrixBuilder._make_target_name( - variable, constraints - ) + target_name = UnifiedMatrixBuilder._make_target_name(variable, constraints) per_district_rows.append( { @@ -745,9 +735,7 @@ def _run_state_via_districts( stratum_id = int(row_data["stratum_id"]) constraints = constraints_map.get(stratum_id, []) - target_name = UnifiedMatrixBuilder._make_target_name( - variable, constraints - ) + target_name = UnifiedMatrixBuilder._make_target_name(variable, constraints) error = sim_value - target_value abs_error = abs(error) @@ -758,9 +746,7 @@ def _run_state_via_districts( rel_error = float("inf") if error != 0 else 0.0 rel_abs_error = float("inf") if abs_error != 0 else 0.0 - sanity_check, sanity_reason = _run_sanity_check( - sim_value, variable, "state" - ) + sanity_check, sanity_reason = _run_sanity_check(sim_value, variable, "state") summary_rows.append( { diff --git a/policyengine_us_data/datasets/acs/acs.py b/policyengine_us_data/datasets/acs/acs.py index 0ecd3ee7..11d1ef73 100644 --- a/policyengine_us_data/datasets/acs/acs.py +++ b/policyengine_us_data/datasets/acs/acs.py @@ -18,9 +18,7 @@ def generate(self) -> None: raw_data = self.census_acs(require=True).load() acs = h5py.File(self.file_path, mode="w") - person, household = [ - raw_data[entity] for entity in ("person", "household") - ] + person, household = [raw_data[entity] for entity in ("person", "household")] self.add_id_variables(acs, person, household) self.add_person_variables(acs, person, household) @@ -39,9 +37,7 @@ def add_id_variables( h_id_to_number = pd.Series( np.arange(len(household)), index=household["SERIALNO"] ) - household["household_id"] = h_id_to_number[ - household["SERIALNO"] - ].values + household["household_id"] = h_id_to_number[household["SERIALNO"]].values person["household_id"] = h_id_to_number[person["SERIALNO"]].values person["person_id"] = person.index + 1 @@ -100,9 +96,7 @@ def add_spm_variables(acs: h5py.File, spm_unit: DataFrame) -> None: @staticmethod def add_household_variables(acs: h5py.File, household: DataFrame) -> None: acs["household_vehicles_owned"] = household.VEH - acs["state_fips"] = acs["household_state_fips"] = household.ST.astype( - int - ) + acs["state_fips"] = acs["household_state_fips"] = household.ST.astype(int) class ACS_2022(ACS): diff --git a/policyengine_us_data/datasets/acs/census_acs.py b/policyengine_us_data/datasets/acs/census_acs.py index 842af627..7bd28bd6 100644 --- a/policyengine_us_data/datasets/acs/census_acs.py +++ b/policyengine_us_data/datasets/acs/census_acs.py @@ -66,9 +66,7 @@ def generate(self) -> None: household = self.process_household_data( household_url, "psam_hus", HOUSEHOLD_COLUMNS ) - person = self.process_person_data( - person_url, "psam_pus", PERSON_COLUMNS - ) + person = self.process_person_data(person_url, "psam_pus", PERSON_COLUMNS) person = person[person.SERIALNO.isin(household.SERIALNO)] household = household[household.SERIALNO.isin(person.SERIALNO)] storage["household"] = household @@ -106,9 +104,7 @@ def process_household_data( return res @staticmethod - def process_person_data( - url: str, prefix: str, columns: List[str] - ) -> pd.DataFrame: + def process_person_data(url: str, prefix: str, columns: List[str]) -> pd.DataFrame: req = requests.get(url, stream=True) with BytesIO() as f: pbar = tqdm() @@ -137,9 +133,7 @@ def process_person_data( return res @staticmethod - def create_spm_unit_table( - storage: pd.HDFStore, person: pd.DataFrame - ) -> None: + def create_spm_unit_table(storage: pd.HDFStore, person: pd.DataFrame) -> None: SPM_UNIT_COLUMNS = [ "CAPHOUSESUB", "CAPWKCCXPNS", @@ -181,12 +175,10 @@ def create_spm_unit_table( # Ensure SERIALNO is treated as string JOIN_COLUMNS = ["SERIALNO", "SPORDER"] - original_person_table["SERIALNO"] = original_person_table[ - "SERIALNO" - ].astype(str) - original_person_table["SPORDER"] = original_person_table[ - "SPORDER" - ].astype(int) + original_person_table["SERIALNO"] = original_person_table["SERIALNO"].astype( + str + ) + original_person_table["SPORDER"] = original_person_table["SPORDER"].astype(int) person["SERIALNO"] = person["SERIALNO"].astype(str) person["SPORDER"] = person["SPORDER"].astype(int) diff --git a/policyengine_us_data/datasets/cps/census_cps.py b/policyengine_us_data/datasets/cps/census_cps.py index 00ca020e..042fefe5 100644 --- a/policyengine_us_data/datasets/cps/census_cps.py +++ b/policyengine_us_data/datasets/cps/census_cps.py @@ -15,9 +15,7 @@ class CensusCPS(Dataset): def generate(self): if self._cps_download_url is None: - raise ValueError( - f"No raw CPS data URL known for year {self.time_period}." - ) + raise ValueError(f"No raw CPS data URL known for year {self.time_period}.") url = self._cps_download_url @@ -28,9 +26,7 @@ def generate(self): ] response = requests.get(url, stream=True) - total_size_in_bytes = int( - response.headers.get("content-length", 200e6) - ) + total_size_in_bytes = int(response.headers.get("content-length", 200e6)) progress_bar = tqdm( total=total_size_in_bytes, unit="iB", @@ -38,9 +34,7 @@ def generate(self): desc="Downloading ASEC", ) if response.status_code == 404: - raise FileNotFoundError( - "Received a 404 response when fetching the data." - ) + raise FileNotFoundError("Received a 404 response when fetching the data.") with BytesIO() as file: content_length_actual = 0 for data in response.iter_content(int(1e6)): @@ -65,33 +59,23 @@ def generate(self): file_prefix = "cpspb/asec/prod/data/2019/" else: file_prefix = "" - with zipfile.open( - f"{file_prefix}pppub{file_year_code}.csv" - ) as f: + with zipfile.open(f"{file_prefix}pppub{file_year_code}.csv") as f: storage["person"] = pd.read_csv( f, - usecols=PERSON_COLUMNS - + spm_unit_columns - + TAX_UNIT_COLUMNS, + usecols=PERSON_COLUMNS + spm_unit_columns + TAX_UNIT_COLUMNS, ).fillna(0) person = storage["person"] - with zipfile.open( - f"{file_prefix}ffpub{file_year_code}.csv" - ) as f: + with zipfile.open(f"{file_prefix}ffpub{file_year_code}.csv") as f: person_family_id = person.PH_SEQ * 10 + person.PF_SEQ family = pd.read_csv(f).fillna(0) family_id = family.FH_SEQ * 10 + family.FFPOS family = family[family_id.isin(person_family_id)] storage["family"] = family - with zipfile.open( - f"{file_prefix}hhpub{file_year_code}.csv" - ) as f: + with zipfile.open(f"{file_prefix}hhpub{file_year_code}.csv") as f: person_household_id = person.PH_SEQ household = pd.read_csv(f).fillna(0) household_id = household.H_SEQ - household = household[ - household_id.isin(person_household_id) - ] + household = household[household_id.isin(person_household_id)] storage["household"] = household storage["tax_unit"] = self._create_tax_unit_table(person) storage["spm_unit"] = self._create_spm_unit_table( diff --git a/policyengine_us_data/datasets/cps/cps.py b/policyengine_us_data/datasets/cps/cps.py index 3ec1f769..418d7396 100644 --- a/policyengine_us_data/datasets/cps/cps.py +++ b/policyengine_us_data/datasets/cps/cps.py @@ -93,9 +93,7 @@ def downsample(self, frac: float): # Store original dtypes before modifying original_data: dict = self.load_dataset() - original_dtypes = { - key: original_data[key].dtype for key in original_data - } + original_dtypes = {key: original_data[key].dtype for key in original_data} sim = Microsimulation(dataset=self) sim.subsample(frac=frac) @@ -208,18 +206,13 @@ def add_takeup(self): aca_rate = load_take_up_rate("aca", self.time_period) medicaid_rates_by_state = load_take_up_rate("medicaid", self.time_period) head_start_rate = load_take_up_rate("head_start", self.time_period) - early_head_start_rate = load_take_up_rate( - "early_head_start", self.time_period - ) + early_head_start_rate = load_take_up_rate("early_head_start", self.time_period) ssi_rate = load_take_up_rate("ssi", self.time_period) # EITC: varies by number of children eitc_child_count = baseline.calculate("eitc_child_count").values eitc_takeup_rate = np.array( - [ - eitc_rates_by_children.get(min(int(c), 3), 0.85) - for c in eitc_child_count - ] + [eitc_rates_by_children.get(min(int(c), 3), 0.85) for c in eitc_child_count] ) rng = seeded_rng("takes_up_eitc") data["takes_up_eitc"] = rng.random(n_tax_units) < eitc_takeup_rate @@ -238,9 +231,7 @@ def add_takeup(self): target_snap_takeup_count = int(snap_rate * n_spm_units) remaining_snap_needed = max(0, target_snap_takeup_count - n_snap_reporters) snap_non_reporter_rate = ( - remaining_snap_needed / n_snap_non_reporters - if n_snap_non_reporters > 0 - else 0 + remaining_snap_needed / n_snap_non_reporters if n_snap_non_reporters > 0 else 0 ) # Assign: all reporters + adjusted rate for non-reporters @@ -257,9 +248,7 @@ def add_takeup(self): hh_ids = data["household_id"] person_hh_ids = data["person_household_id"] hh_to_state = dict(zip(hh_ids, state_codes)) - person_states = np.array( - [hh_to_state.get(hh_id, "CA") for hh_id in person_hh_ids] - ) + person_states = np.array([hh_to_state.get(hh_id, "CA") for hh_id in person_hh_ids]) medicaid_rate_by_person = np.array( [medicaid_rates_by_state.get(s, 0.93) for s in person_states] ) @@ -270,9 +259,7 @@ def add_takeup(self): # Head Start rng = seeded_rng("takes_up_head_start_if_eligible") - data["takes_up_head_start_if_eligible"] = ( - rng.random(n_persons) < head_start_rate - ) + data["takes_up_head_start_if_eligible"] = rng.random(n_persons) < head_start_rate # Early Head Start rng = seeded_rng("takes_up_early_head_start_if_eligible") @@ -290,9 +277,7 @@ def add_takeup(self): target_ssi_takeup_count = int(ssi_rate * n_persons) remaining_ssi_needed = max(0, target_ssi_takeup_count - n_ssi_reporters) ssi_non_reporter_rate = ( - remaining_ssi_needed / n_ssi_non_reporters - if n_ssi_non_reporters > 0 - else 0 + remaining_ssi_needed / n_ssi_non_reporters if n_ssi_non_reporters > 0 else 0 ) # Assign: all reporters + adjusted rate for non-reporters @@ -315,9 +300,7 @@ def add_takeup(self): data["would_claim_wic"] = rng.random(n_persons) < wic_takeup_rate_by_person # WIC nutritional risk — fully resolved - wic_risk_rates = load_take_up_rate( - "wic_nutritional_risk", self.time_period - ) + wic_risk_rates = load_take_up_rate("wic_nutritional_risk", self.time_period) wic_risk_rate_by_person = np.array( [wic_risk_rates.get(c, 0) for c in wic_categories] ) @@ -364,12 +347,8 @@ def uprate_cps_data(data, from_period, to_period): uprating = create_policyengine_uprating_factors_table() for variable in uprating.index.unique(): if variable in data: - current_index = uprating[uprating.index == variable][ - to_period - ].values[0] - start_index = uprating[uprating.index == variable][ - from_period - ].values[0] + current_index = uprating[uprating.index == variable][to_period].values[0] + start_index = uprating[uprating.index == variable][from_period].values[0] growth = current_index / start_index data[variable] = data[variable] * growth @@ -411,9 +390,7 @@ def add_id_variables( # Marital units - marital_unit_id = person.PH_SEQ * 1e6 + np.maximum( - person.A_LINENO, person.A_SPOUSE - ) + marital_unit_id = person.PH_SEQ * 1e6 + np.maximum(person.A_LINENO, person.A_SPOUSE) # marital_unit_id is not the household ID, zero padded and followed # by the index within household (of each person, or their spouse if @@ -453,9 +430,7 @@ def add_personal_variables(cps: h5py.File, person: DataFrame) -> None: # "Is...blind or does...have serious difficulty seeing even when Wearing # glasses?" 1 -> Yes cps["is_blind"] = person.PEDISEYE == 1 - DISABILITY_FLAGS = [ - "PEDIS" + i for i in ["DRS", "EAR", "EYE", "OUT", "PHY", "REM"] - ] + DISABILITY_FLAGS = ["PEDIS" + i for i in ["DRS", "EAR", "EYE", "OUT", "PHY", "REM"]] cps["is_disabled"] = (person[DISABILITY_FLAGS] == 1).any(axis=1) def children_per_parent(col: str) -> pd.DataFrame: @@ -477,9 +452,7 @@ def children_per_parent(col: str) -> pd.DataFrame: # Aggregate to parent. res = ( - pd.concat( - [children_per_parent("PEPAR1"), children_per_parent("PEPAR2")] - ) + pd.concat([children_per_parent("PEPAR1"), children_per_parent("PEPAR2")]) .groupby(["PH_SEQ", "A_LINENO"]) .children.sum() .reset_index() @@ -505,9 +478,7 @@ def children_per_parent(col: str) -> pd.DataFrame: add_overtime_occupation(cps, person) -def add_personal_income_variables( - cps: h5py.File, person: DataFrame, year: int -): +def add_personal_income_variables(cps: h5py.File, person: DataFrame, year: int): """Add income variables. Args: @@ -533,16 +504,14 @@ def add_personal_income_variables( cps["weekly_hours_worked"] = person.HRSWK cps["hours_worked_last_week"] = person.A_HRS1 - cps["taxable_interest_income"] = person.INT_VAL * ( - p["taxable_interest_fraction"] - ) + cps["taxable_interest_income"] = person.INT_VAL * (p["taxable_interest_fraction"]) cps["tax_exempt_interest_income"] = person.INT_VAL * ( 1 - p["taxable_interest_fraction"] ) cps["self_employment_income"] = person.SEMP_VAL cps["farm_income"] = person.FRSE_VAL - cps["qualified_dividend_income"] = person.DIV_VAL * ( - p["qualified_dividend_fraction"] + cps["qualified_dividend_income"] = ( + person.DIV_VAL * (p["qualified_dividend_fraction"]) ) cps["non_qualified_dividend_income"] = person.DIV_VAL * ( 1 - p["qualified_dividend_fraction"] @@ -561,18 +530,14 @@ def add_personal_income_variables( # 8 = Other is_retirement = (person.RESNSS1 == 1) | (person.RESNSS2 == 1) is_disability = (person.RESNSS1 == 2) | (person.RESNSS2 == 2) - is_survivor = np.isin(person.RESNSS1, [3, 5]) | np.isin( - person.RESNSS2, [3, 5] - ) + is_survivor = np.isin(person.RESNSS1, [3, 5]) | np.isin(person.RESNSS2, [3, 5]) is_dependent = np.isin(person.RESNSS1, [4, 6, 7]) | np.isin( person.RESNSS2, [4, 6, 7] ) # Primary classification: assign full SS_VAL to the highest- # priority category when someone has multiple source codes. - cps["social_security_retirement"] = np.where( - is_retirement, person.SS_VAL, 0 - ) + cps["social_security_retirement"] = np.where(is_retirement, person.SS_VAL, 0) cps["social_security_disability"] = np.where( is_disability & ~is_retirement, person.SS_VAL, 0 ) @@ -615,9 +580,7 @@ def add_personal_income_variables( # Add pensions and annuities. cps_pensions = person.PNSN_VAL + person.ANN_VAL # Assume a constant fraction of pension income is taxable. - cps["taxable_private_pension_income"] = ( - cps_pensions * p["taxable_pension_fraction"] - ) + cps["taxable_private_pension_income"] = cps_pensions * p["taxable_pension_fraction"] cps["tax_exempt_private_pension_income"] = cps_pensions * ( 1 - p["taxable_pension_fraction"] ) @@ -641,18 +604,11 @@ def add_personal_income_variables( for source_with_taxable_fraction in ["401k", "403b", "sep"]: cps[f"taxable_{source_with_taxable_fraction}_distributions"] = ( cps[f"{source_with_taxable_fraction}_distributions"] - * p[ - f"taxable_{source_with_taxable_fraction}_distribution_fraction" - ] + * p[f"taxable_{source_with_taxable_fraction}_distribution_fraction"] ) cps[f"tax_exempt_{source_with_taxable_fraction}_distributions"] = cps[ f"{source_with_taxable_fraction}_distributions" - ] * ( - 1 - - p[ - f"taxable_{source_with_taxable_fraction}_distribution_fraction" - ] - ) + ] * (1 - p[f"taxable_{source_with_taxable_fraction}_distribution_fraction"]) del cps[f"{source_with_taxable_fraction}_distributions"] # Assume all regular IRA distributions are taxable, @@ -740,9 +696,7 @@ def add_personal_income_variables( cps["traditional_ira_contributions"] = ira_capped * trad_ira_share cps["roth_ira_contributions"] = ira_capped * (1 - trad_ira_share) # Allocate capital gains into long-term and short-term based on aggregate split. - cps["long_term_capital_gains"] = person.CAP_VAL * ( - p["long_term_capgain_fraction"] - ) + cps["long_term_capital_gains"] = person.CAP_VAL * (p["long_term_capgain_fraction"]) cps["short_term_capital_gains"] = person.CAP_VAL * ( 1 - p["long_term_capgain_fraction"] ) @@ -770,10 +724,7 @@ def add_personal_income_variables( # Get QBI simulation parameters --- yamlfilename = ( - files("policyengine_us_data") - / "datasets" - / "puf" - / "qbi_assumptions.yaml" + files("policyengine_us_data") / "datasets" / "puf" / "qbi_assumptions.yaml" ) with open(yamlfilename, "r", encoding="utf-8") as yamlfile: p = yaml.safe_load(yamlfile) @@ -827,14 +778,10 @@ def add_spm_variables(self, cps: h5py.File, spm_unit: DataFrame) -> None: 3: "RENTER", } cps["spm_unit_tenure_type"] = ( - spm_unit.SPM_TENMORTSTATUS.map(tenure_map) - .fillna("RENTER") - .astype("S") + spm_unit.SPM_TENMORTSTATUS.map(tenure_map).fillna("RENTER").astype("S") ) - cps["reduced_price_school_meals_reported"] = ( - cps["free_school_meals_reported"] * 0 - ) + cps["reduced_price_school_meals_reported"] = cps["free_school_meals_reported"] * 0 def add_household_variables(cps: h5py.File, household: DataFrame) -> None: @@ -968,9 +915,7 @@ def select_random_subset_to_target( share_to_move = min(share_to_move, 1.0) # Cap at 100% else: # Calculate how much to move to reach target (for EAD case) - needed_weighted = ( - current_weighted - target_weighted - ) # Will be negative + needed_weighted = current_weighted - target_weighted # Will be negative total_weight = np.sum(person_weights[eligible_ids]) share_to_move = abs(needed_weighted) / total_weight share_to_move = min(share_to_move, 1.0) # Cap at 100% @@ -1214,9 +1159,7 @@ def select_random_subset_to_target( ) # CONDITION 10: Government Employees - is_government_worker = np.isin( - person.PEIO1COW, [1, 2, 3] - ) # Fed/state/local gov + is_government_worker = np.isin(person.PEIO1COW, [1, 2, 3]) # Fed/state/local gov is_military_occupation = person.A_MJOCC == 11 # Military occupation is_government_employee = is_government_worker | is_military_occupation condition_10_mask = potentially_undocumented & is_government_employee @@ -1330,12 +1273,8 @@ def select_random_subset_to_target( undocumented_students_mask = ( (ssn_card_type == 0) & noncitizens & (person.A_HSCOL == 2) ) - undocumented_workers_count = np.sum( - person_weights[undocumented_workers_mask] - ) - undocumented_students_count = np.sum( - person_weights[undocumented_students_mask] - ) + undocumented_workers_count = np.sum(person_weights[undocumented_workers_mask]) + undocumented_students_count = np.sum(person_weights[undocumented_students_mask]) after_conditions_code_0 = np.sum(person_weights[ssn_card_type == 0]) print(f"After conditions - Code 0 people: {after_conditions_code_0:,.0f}") @@ -1530,15 +1469,11 @@ def select_random_subset_to_target( f"Selected {len(selected_indices)} people from {len(mixed_household_candidates)} candidates in mixed households" ) else: - print( - "No additional family members selected (target already reached)" - ) + print("No additional family members selected (target already reached)") else: print("No mixed-status households found for family correlation") else: - print( - "No additional undocumented people needed - target already reached" - ) + print("No additional undocumented people needed - target already reached") # Calculate the weighted impact code_0_after = np.sum(person_weights[ssn_card_type == 0]) @@ -1613,9 +1548,7 @@ def get_arrival_year_midpoint(peinusyr): age_at_entry = np.maximum(0, person.A_AGE - years_in_us) # start every non-citizen as LPR so no UNSET survives - immigration_status = np.full( - len(person), "LEGAL_PERMANENT_RESIDENT", dtype="U32" - ) + immigration_status = np.full(len(person), "LEGAL_PERMANENT_RESIDENT", dtype="U32") # Set citizens (SSN card type 1) to CITIZEN status immigration_status[ssn_card_type == 1] = "CITIZEN" @@ -1663,9 +1596,7 @@ def get_arrival_year_midpoint(peinusyr): immigration_status[recent_refugee_mask] = "REFUGEE" # 6. Temp non-qualified (Code 2 not caught by DACA rule) - mask = (ssn_card_type == 2) & ( - immigration_status == "LEGAL_PERMANENT_RESIDENT" - ) + mask = (ssn_card_type == 2) & (immigration_status == "LEGAL_PERMANENT_RESIDENT") immigration_status[mask] = "TPS" # Final write (all values now in ImmigrationStatus Enum) @@ -1681,9 +1612,7 @@ def get_arrival_year_midpoint(peinusyr): 2: "NON_CITIZEN_VALID_EAD", # Non-citizens with work/study authorization 3: "OTHER_NON_CITIZEN", # Non-citizens with indicators of legal status } - ssn_card_type_str = ( - pd.Series(ssn_card_type).map(code_to_str).astype("S").values - ) + ssn_card_type_str = pd.Series(ssn_card_type).map(code_to_str).astype("S").values cps["ssn_card_type"] = ssn_card_type_str # Final population summary @@ -1890,9 +1819,7 @@ def add_tips(self, cps: h5py.File): # Drop temporary columns used only for imputation # is_married is person-level here but policyengine-us defines it at Family # level, so we must not save it - cps = cps.drop( - columns=["is_married", "is_under_18", "is_under_6"], errors="ignore" - ) + cps = cps.drop(columns=["is_married", "is_under_18", "is_under_6"], errors="ignore") self.save_dataset(cps) @@ -2012,9 +1939,7 @@ def create_scf_reference_person_mask(cps_data, raw_person_data): all_persons_data["is_female"] = (raw_person_data.A_SEX == 2).values # Add marital status (A_MARITL codes: 1,2 = married with spouse present/absent) - all_persons_data["is_married"] = raw_person_data.A_MARITL.isin( - [1, 2] - ).values + all_persons_data["is_married"] = raw_person_data.A_MARITL.isin([1, 2]).values # Define adults as age 18+ all_persons_data["is_adult"] = all_persons_data["age"] >= 18 @@ -2033,8 +1958,7 @@ def create_scf_reference_person_mask(cps_data, raw_person_data): # Identify couple households (households with exactly 2 married adults) married_adults_per_household = ( all_persons_data[ - (all_persons_data["is_adult"]) - & (all_persons_data["is_married"]) + (all_persons_data["is_adult"]) & (all_persons_data["is_married"]) ] .groupby("person_household_id") .size() @@ -2042,12 +1966,7 @@ def create_scf_reference_person_mask(cps_data, raw_person_data): couple_households = married_adults_per_household[ (married_adults_per_household == 2) - & ( - all_persons_data.groupby("person_household_id")[ - "n_adults" - ].first() - == 2 - ) + & (all_persons_data.groupby("person_household_id")["n_adults"].first() == 2) ].index all_persons_data["is_couple_household"] = all_persons_data[ @@ -2147,9 +2066,7 @@ def determine_reference_person(group): } # Apply the mapping to recode the race values - cps_data["cps_race"] = np.vectorize(CPS_RACE_MAPPING.get)( - cps_data["cps_race"] - ) + cps_data["cps_race"] = np.vectorize(CPS_RACE_MAPPING.get)(cps_data["cps_race"]) lengths = {k: len(v) for k, v in cps_data.items()} var_len = cps_data["person_household_id"].shape[0] @@ -2181,9 +2098,7 @@ def determine_reference_person(group): # Add is_married variable for household heads based on raw person data reference_persons = person_data[mask] - receiver_data["is_married"] = reference_persons.A_MARITL.isin( - [1, 2] - ).values + receiver_data["is_married"] = reference_persons.A_MARITL.isin([1, 2]).values # Impute auto loan balance from the SCF from policyengine_us_data.datasets.scf.scf import SCF_2022 @@ -2218,9 +2133,7 @@ def determine_reference_person(group): logging.getLogger("microimpute").setLevel(getattr(logging, log_level)) qrf_model = QRF() - donor_data = donor_data.sample(frac=0.5, random_state=42).reset_index( - drop=True - ) + donor_data = donor_data.sample(frac=0.5, random_state=42).reset_index(drop=True) fitted_model = qrf_model.fit( X_train=donor_data, predictors=PREDICTORS, diff --git a/policyengine_us_data/datasets/cps/enhanced_cps.py b/policyengine_us_data/datasets/cps/enhanced_cps.py index 2b3b46ef..8755c73e 100644 --- a/policyengine_us_data/datasets/cps/enhanced_cps.py +++ b/policyengine_us_data/datasets/cps/enhanced_cps.py @@ -46,9 +46,7 @@ def reweight( normalisation_factor = np.where( is_national, nation_normalisation_factor, state_normalisation_factor ) - normalisation_factor = torch.tensor( - normalisation_factor, dtype=torch.float32 - ) + normalisation_factor = torch.tensor(normalisation_factor, dtype=torch.float32) targets_array = torch.tensor(targets_array, dtype=torch.float32) inv_mean_normalisation = 1 / np.mean(normalisation_factor.numpy()) @@ -61,12 +59,8 @@ def loss(weights): estimate = weights @ loss_matrix if torch.isnan(estimate).any(): raise ValueError("Estimate contains NaNs") - rel_error = ( - ((estimate - targets_array) + 1) / (targets_array + 1) - ) ** 2 - rel_error_normalized = ( - inv_mean_normalisation * rel_error * normalisation_factor - ) + rel_error = (((estimate - targets_array) + 1) / (targets_array + 1)) ** 2 + rel_error_normalized = inv_mean_normalisation * rel_error * normalisation_factor if torch.isnan(rel_error_normalized).any(): raise ValueError("Relative error contains NaNs") return rel_error_normalized.mean() @@ -121,9 +115,7 @@ def loss(weights): start_loss = l.item() loss_rel_change = (l.item() - start_loss) / start_loss l.backward() - iterator.set_postfix( - {"loss": l.item(), "loss_rel_change": loss_rel_change} - ) + iterator.set_postfix({"loss": l.item(), "loss_rel_change": loss_rel_change}) optimizer.step() if log_path is not None: performance.to_csv(log_path, index=False) @@ -182,9 +174,7 @@ def generate(self): # Run the optimization procedure to get (close to) minimum loss weights for year in range(self.start_year, self.end_year + 1): - loss_matrix, targets_array = build_loss_matrix( - self.input_dataset, year - ) + loss_matrix, targets_array = build_loss_matrix(self.input_dataset, year) zero_mask = np.isclose(targets_array, 0.0, atol=0.1) bad_mask = loss_matrix.columns.isin(bad_targets) keep_mask_bool = ~(zero_mask | bad_mask) @@ -210,9 +200,7 @@ def generate(self): # Validate dense weights w = optimised_weights if np.any(np.isnan(w)): - raise ValueError( - f"Year {year}: household_weight contains NaN values" - ) + raise ValueError(f"Year {year}: household_weight contains NaN values") if np.any(w < 0): raise ValueError( f"Year {year}: household_weight contains negative values" @@ -253,12 +241,8 @@ def generate(self): 1, 0.1, len(original_weights) ) for year in [2024]: - loss_matrix, targets_array = build_loss_matrix( - self.input_dataset, year - ) - optimised_weights = reweight( - original_weights, loss_matrix, targets_array - ) + loss_matrix, targets_array = build_loss_matrix(self.input_dataset, year) + optimised_weights = reweight(original_weights, loss_matrix, targets_array) data["household_weight"] = optimised_weights self.save_dataset(data) diff --git a/policyengine_us_data/datasets/cps/long_term/check_calibrated_estimates_interactive.py b/policyengine_us_data/datasets/cps/long_term/check_calibrated_estimates_interactive.py index 28bdfd3e..5fe3e599 100644 --- a/policyengine_us_data/datasets/cps/long_term/check_calibrated_estimates_interactive.py +++ b/policyengine_us_data/datasets/cps/long_term/check_calibrated_estimates_interactive.py @@ -24,8 +24,7 @@ ## Taxable Payroll for Social Security taxible_estimate_b = ( sim.calculate("taxable_earnings_for_social_security").sum() / 1e9 - + sim.calculate("social_security_taxable_self_employment_income").sum() - / 1e9 + + sim.calculate("social_security_taxable_self_employment_income").sum() / 1e9 ) ### Trustees SingleYearTRTables_TR2025.xlsx, Tab VI.G6 (nominal dollars in billions) @@ -66,8 +65,7 @@ ## Taxable Payroll for Social Security taxible_estimate_b = ( sim.calculate("taxable_earnings_for_social_security").sum() / 1e9 - + sim.calculate("social_security_taxable_self_employment_income").sum() - / 1e9 + + sim.calculate("social_security_taxable_self_employment_income").sum() / 1e9 ) ### Trustees SingleYearTRTables_TR2025.xlsx, Tab VI.G6 (nominal dollars in billions) @@ -175,9 +173,9 @@ def create_h6_reform(): # The swapped rate error is 14x smaller and aligns with tax-cutting intent. # Tier 1 (Base): HI ONLY (35%) - reform_payload[ - "gov.irs.social_security.taxability.rate.base.benefit_cap" - ][period] = 0.35 + reform_payload["gov.irs.social_security.taxability.rate.base.benefit_cap"][ + period + ] = 0.35 reform_payload["gov.irs.social_security.taxability.rate.base.excess"][ period ] = 0.35 @@ -186,25 +184,25 @@ def create_h6_reform(): reform_payload[ "gov.irs.social_security.taxability.rate.additional.benefit_cap" ][period] = 0.85 - reform_payload[ - "gov.irs.social_security.taxability.rate.additional.excess" - ][period] = 0.85 + reform_payload["gov.irs.social_security.taxability.rate.additional.excess"][ + period + ] = 0.85 # --- SET THRESHOLDS (MIN/MAX SWAP) --- # Always put the smaller number in 'base' and larger in 'adjusted_base' # Single - reform_payload[ - "gov.irs.social_security.taxability.threshold.base.main.SINGLE" - ][period] = min(oasdi_target_single, HI_SINGLE) + reform_payload["gov.irs.social_security.taxability.threshold.base.main.SINGLE"][ + period + ] = min(oasdi_target_single, HI_SINGLE) reform_payload[ "gov.irs.social_security.taxability.threshold.adjusted_base.main.SINGLE" ][period] = max(oasdi_target_single, HI_SINGLE) # Joint - reform_payload[ - "gov.irs.social_security.taxability.threshold.base.main.JOINT" - ][period] = min(oasdi_target_joint, HI_JOINT) + reform_payload["gov.irs.social_security.taxability.threshold.base.main.JOINT"][ + period + ] = min(oasdi_target_joint, HI_JOINT) reform_payload[ "gov.irs.social_security.taxability.threshold.adjusted_base.main.JOINT" ][period] = max(oasdi_target_joint, HI_JOINT) @@ -228,12 +226,12 @@ def create_h6_reform(): # 1. Set Thresholds to "HI Only" mode # Base = $34k / $44k - reform_payload[ - "gov.irs.social_security.taxability.threshold.base.main.SINGLE" - ][elim_period] = HI_SINGLE - reform_payload[ - "gov.irs.social_security.taxability.threshold.base.main.JOINT" - ][elim_period] = HI_JOINT + reform_payload["gov.irs.social_security.taxability.threshold.base.main.SINGLE"][ + elim_period + ] = HI_SINGLE + reform_payload["gov.irs.social_security.taxability.threshold.base.main.JOINT"][ + elim_period + ] = HI_JOINT # Adjusted = Infinity (Disable the second tier effectively) reform_payload[ @@ -262,12 +260,12 @@ def create_h6_reform(): ] = 0.35 # Tier 2 (Disabled via threshold, but zero out for safety) - reform_payload[ - "gov.irs.social_security.taxability.rate.additional.benefit_cap" - ][elim_period] = 0.35 - reform_payload[ - "gov.irs.social_security.taxability.rate.additional.excess" - ][elim_period] = 0.35 + reform_payload["gov.irs.social_security.taxability.rate.additional.benefit_cap"][ + elim_period + ] = 0.35 + reform_payload["gov.irs.social_security.taxability.rate.additional.excess"][ + elim_period + ] = 0.35 return reform_payload @@ -298,23 +296,17 @@ def create_h6_reform(): print(f"revenue_impact (B): {revenue_impact / 1e9:.2f}") # Calculate taxable payroll -taxable_ss_earnings = baseline.calculate( - "taxable_earnings_for_social_security" -) +taxable_ss_earnings = baseline.calculate("taxable_earnings_for_social_security") taxable_self_employment = baseline.calculate( "social_security_taxable_self_employment_income" ) -total_taxable_payroll = ( - taxable_ss_earnings.sum() + taxable_self_employment.sum() -) +total_taxable_payroll = taxable_ss_earnings.sum() + taxable_self_employment.sum() # Calculate SS benefits ss_benefits = baseline.calculate("social_security") total_ss_benefits = ss_benefits.sum() -est_rev_as_pct_of_taxable_payroll = ( - 100 * revenue_impact / total_taxable_payroll -) +est_rev_as_pct_of_taxable_payroll = 100 * revenue_impact / total_taxable_payroll # From https://www.ssa.gov/oact/solvency/provisions/tables/table_run133.html: target_rev_as_pct_of_taxable_payroll = -1.12 diff --git a/policyengine_us_data/datasets/cps/long_term/extract_ssa_costs.py b/policyengine_us_data/datasets/cps/long_term/extract_ssa_costs.py index 5ada2db9..492a9d69 100644 --- a/policyengine_us_data/datasets/cps/long_term/extract_ssa_costs.py +++ b/policyengine_us_data/datasets/cps/long_term/extract_ssa_costs.py @@ -2,9 +2,7 @@ import numpy as np # Read the file -df = pd.read_excel( - "SingleYearTRTables_TR2025.xlsx", sheet_name="VI.G9", header=None -) +df = pd.read_excel("SingleYearTRTables_TR2025.xlsx", sheet_name="VI.G9", header=None) print("DataFrame shape:", df.shape) print("\nChecking data types around row 66-70:") diff --git a/policyengine_us_data/datasets/cps/long_term/projection_utils.py b/policyengine_us_data/datasets/cps/long_term/projection_utils.py index d0af8533..8aee4f3b 100644 --- a/policyengine_us_data/datasets/cps/long_term/projection_utils.py +++ b/policyengine_us_data/datasets/cps/long_term/projection_utils.py @@ -27,9 +27,7 @@ def build_household_age_matrix(sim, n_ages=86): n_households = len(household_ids_unique) X = np.zeros((n_households, n_ages)) - hh_id_to_idx = { - hh_id: idx for idx, hh_id in enumerate(household_ids_unique) - } + hh_id_to_idx = {hh_id: idx for idx, hh_id in enumerate(household_ids_unique)} for person_idx in range(len(age_person)): age = int(age_person.values[person_idx]) @@ -67,9 +65,7 @@ def get_pseudo_input_variables(sim): return pseudo_inputs -def create_household_year_h5( - year, household_weights, base_dataset_path, output_dir -): +def create_household_year_h5(year, household_weights, base_dataset_path, output_dir): """ Create a year-specific .h5 file with calibrated household weights. @@ -193,9 +189,7 @@ def calculate_year_statistics( Returns: Dictionary with year statistics and calibrated weights """ - income_tax_hh = sim.calculate( - "income_tax", period=year, map_to="household" - ) + income_tax_hh = sim.calculate("income_tax", period=year, map_to="household") income_tax_baseline_total = income_tax_hh.sum() income_tax_values = income_tax_hh.values @@ -206,9 +200,7 @@ def calculate_year_statistics( ss_values = None ss_target = None if use_ss: - ss_hh = sim.calculate( - "social_security", period=year, map_to="household" - ) + ss_hh = sim.calculate("social_security", period=year, map_to="household") ss_baseline_total = ss_hh.sum() ss_values = ss_hh.values diff --git a/policyengine_us_data/datasets/cps/long_term/run_household_projection.py b/policyengine_us_data/datasets/cps/long_term/run_household_projection.py index 30d1857a..1413efe4 100644 --- a/policyengine_us_data/datasets/cps/long_term/run_household_projection.py +++ b/policyengine_us_data/datasets/cps/long_term/run_household_projection.py @@ -105,9 +105,9 @@ def create_h6_reform(): # The swapped rate error is 14x smaller and aligns with tax-cutting intent. # Tier 1 (Base): HI ONLY (35%) - reform_payload[ - "gov.irs.social_security.taxability.rate.base.benefit_cap" - ][period] = 0.35 + reform_payload["gov.irs.social_security.taxability.rate.base.benefit_cap"][ + period + ] = 0.35 reform_payload["gov.irs.social_security.taxability.rate.base.excess"][ period ] = 0.35 @@ -116,25 +116,25 @@ def create_h6_reform(): reform_payload[ "gov.irs.social_security.taxability.rate.additional.benefit_cap" ][period] = 0.85 - reform_payload[ - "gov.irs.social_security.taxability.rate.additional.excess" - ][period] = 0.85 + reform_payload["gov.irs.social_security.taxability.rate.additional.excess"][ + period + ] = 0.85 # --- SET THRESHOLDS (MIN/MAX SWAP) --- # Always put the smaller number in 'base' and larger in 'adjusted_base' # Single - reform_payload[ - "gov.irs.social_security.taxability.threshold.base.main.SINGLE" - ][period] = min(oasdi_target_single, HI_SINGLE) + reform_payload["gov.irs.social_security.taxability.threshold.base.main.SINGLE"][ + period + ] = min(oasdi_target_single, HI_SINGLE) reform_payload[ "gov.irs.social_security.taxability.threshold.adjusted_base.main.SINGLE" ][period] = max(oasdi_target_single, HI_SINGLE) # Joint - reform_payload[ - "gov.irs.social_security.taxability.threshold.base.main.JOINT" - ][period] = min(oasdi_target_joint, HI_JOINT) + reform_payload["gov.irs.social_security.taxability.threshold.base.main.JOINT"][ + period + ] = min(oasdi_target_joint, HI_JOINT) reform_payload[ "gov.irs.social_security.taxability.threshold.adjusted_base.main.JOINT" ][period] = max(oasdi_target_joint, HI_JOINT) @@ -158,12 +158,12 @@ def create_h6_reform(): # 1. Set Thresholds to "HI Only" mode # Base = $34k / $44k - reform_payload[ - "gov.irs.social_security.taxability.threshold.base.main.SINGLE" - ][elim_period] = HI_SINGLE - reform_payload[ - "gov.irs.social_security.taxability.threshold.base.main.JOINT" - ][elim_period] = HI_JOINT + reform_payload["gov.irs.social_security.taxability.threshold.base.main.SINGLE"][ + elim_period + ] = HI_SINGLE + reform_payload["gov.irs.social_security.taxability.threshold.base.main.JOINT"][ + elim_period + ] = HI_JOINT # Adjusted = Infinity (Disable the second tier effectively) reform_payload[ @@ -192,12 +192,12 @@ def create_h6_reform(): ] = 0.35 # Tier 2 (Disabled via threshold, but zero out for safety) - reform_payload[ - "gov.irs.social_security.taxability.rate.additional.benefit_cap" - ][elim_period] = 0.35 - reform_payload[ - "gov.irs.social_security.taxability.rate.additional.excess" - ][elim_period] = 0.35 + reform_payload["gov.irs.social_security.taxability.rate.additional.benefit_cap"][ + elim_period + ] = 0.35 + reform_payload["gov.irs.social_security.taxability.rate.additional.excess"][ + elim_period + ] = 0.35 # Create the Reform Object from policyengine_core.reforms import Reform @@ -242,18 +242,14 @@ def create_h6_reform(): if USE_PAYROLL: sys.argv.remove("--use-payroll") if not USE_GREG: - print( - "Warning: --use-payroll requires --greg, enabling GREG automatically" - ) + print("Warning: --use-payroll requires --greg, enabling GREG automatically") USE_GREG = True USE_H6_REFORM = "--use-h6-reform" in sys.argv if USE_H6_REFORM: sys.argv.remove("--use-h6-reform") if not USE_GREG: - print( - "Warning: --use-h6-reform requires --greg, enabling GREG automatically" - ) + print("Warning: --use-h6-reform requires --greg, enabling GREG automatically") USE_GREG = True from ssa_data import load_h6_income_rate_change @@ -261,9 +257,7 @@ def create_h6_reform(): if USE_TOB: sys.argv.remove("--use-tob") if not USE_GREG: - print( - "Warning: --use-tob requires --greg, enabling GREG automatically" - ) + print("Warning: --use-tob requires --greg, enabling GREG automatically") USE_GREG = True from ssa_data import load_oasdi_tob_projections, load_hi_tob_projections @@ -320,9 +314,7 @@ def create_h6_reform(): print("STEP 1: DEMOGRAPHIC PROJECTIONS") print("=" * 70) -target_matrix = load_ssa_age_projections( - start_year=START_YEAR, end_year=END_YEAR -) +target_matrix = load_ssa_age_projections(start_year=START_YEAR, end_year=END_YEAR) n_years = target_matrix.shape[1] n_ages = target_matrix.shape[0] @@ -390,9 +382,7 @@ def create_h6_reform(): sim = Microsimulation(dataset=BASE_DATASET_PATH) - income_tax_hh = sim.calculate( - "income_tax", period=year, map_to="household" - ) + income_tax_hh = sim.calculate("income_tax", period=year, map_to="household") income_tax_baseline_total = income_tax_hh.sum() income_tax_values = income_tax_hh.values @@ -405,9 +395,7 @@ def create_h6_reform(): ss_values = None ss_target = None if USE_SS: - ss_hh = sim.calculate( - "social_security", period=year, map_to="household" - ) + ss_hh = sim.calculate("social_security", period=year, map_to="household") ss_values = ss_hh.values ss_target = load_ssa_benefit_projections(year) if year in display_years: @@ -452,9 +440,7 @@ def create_h6_reform(): else: # Create and apply H6 reform h6_reform = create_h6_reform() - reform_sim = Microsimulation( - dataset=BASE_DATASET_PATH, reform=h6_reform - ) + reform_sim = Microsimulation(dataset=BASE_DATASET_PATH, reform=h6_reform) # Calculate reform income tax income_tax_reform_hh = reform_sim.calculate( @@ -472,9 +458,7 @@ def create_h6_reform(): # Debug output for key years if year in display_years: - h6_impact_baseline = np.sum( - h6_income_values * baseline_weights - ) + h6_impact_baseline = np.sum(h6_income_values * baseline_weights) print( f" [DEBUG {year}] H6 baseline revenue: ${h6_impact_baseline / 1e9:.3f}B, target: ${h6_revenue_target / 1e9:.3f}B" ) @@ -547,13 +531,9 @@ def create_h6_reform(): f"largest: {max_neg:,.0f}" ) else: - print( - f" [DEBUG {year}] Negative weights: 0 (all weights non-negative)" - ) + print(f" [DEBUG {year}] Negative weights: 0 (all weights non-negative)") - if year in display_years and ( - USE_SS or USE_PAYROLL or USE_H6_REFORM or USE_TOB - ): + if year in display_years and (USE_SS or USE_PAYROLL or USE_H6_REFORM or USE_TOB): if USE_SS: ss_achieved = np.sum(ss_values * w_new) print( @@ -567,9 +547,7 @@ def create_h6_reform(): if USE_H6_REFORM and h6_revenue_target is not None: h6_revenue_achieved = np.sum(h6_income_values * w_new) error_pct = ( - (h6_revenue_achieved - h6_revenue_target) - / abs(h6_revenue_target) - * 100 + (h6_revenue_achieved - h6_revenue_target) / abs(h6_revenue_target) * 100 if h6_revenue_target != 0 else 0 ) @@ -593,9 +571,7 @@ def create_h6_reform(): total_population[year_idx] = np.sum(y_target) if SAVE_H5: - h5_path = create_household_year_h5( - year, w_new, BASE_DATASET_PATH, OUTPUT_DIR - ) + h5_path = create_household_year_h5(year, w_new, BASE_DATASET_PATH, OUTPUT_DIR) if year in display_years: print(f" Saved {year}.h5") diff --git a/policyengine_us_data/datasets/cps/small_enhanced_cps.py b/policyengine_us_data/datasets/cps/small_enhanced_cps.py index 53607d03..a1508032 100644 --- a/policyengine_us_data/datasets/cps/small_enhanced_cps.py +++ b/policyengine_us_data/datasets/cps/small_enhanced_cps.py @@ -35,9 +35,7 @@ def create_small_ecps(): data[variable] = {} for time_period in simulation.get_holder(variable).get_known_periods(): values = simulation.get_holder(variable).get_array(time_period) - if simulation.tax_benefit_system.variables.get( - variable - ).value_type in ( + if simulation.tax_benefit_system.variables.get(variable).value_type in ( Enum, str, ): @@ -114,8 +112,7 @@ def create_sparse_ecps(): for time_period in sim.get_holder(variable).get_known_periods(): values = sim.get_holder(variable).get_array(time_period) if ( - sim.tax_benefit_system.variables.get(variable).value_type - in (Enum, str) + sim.tax_benefit_system.variables.get(variable).value_type in (Enum, str) and variable != "county_fips" ): values = values.decode_to_str().astype("S") @@ -138,9 +135,7 @@ def create_sparse_ecps(): ] missing = [v for v in critical_vars if v not in data] if missing: - raise ValueError( - f"create_sparse_ecps: missing critical variables: {missing}" - ) + raise ValueError(f"create_sparse_ecps: missing critical variables: {missing}") logging.info(f"create_sparse_ecps: data dict has {len(data)} variables") output_path = STORAGE_FOLDER / "sparse_enhanced_cps_2024.h5" @@ -155,9 +150,7 @@ def create_sparse_ecps(): raise ValueError( f"create_sparse_ecps: output file only {file_size:,} bytes (expected > 1MB)" ) - logging.info( - f"create_sparse_ecps: wrote {file_size / 1e6:.1f}MB to {output_path}" - ) + logging.info(f"create_sparse_ecps: wrote {file_size / 1e6:.1f}MB to {output_path}") if __name__ == "__main__": diff --git a/policyengine_us_data/datasets/puf/irs_puf.py b/policyengine_us_data/datasets/puf/irs_puf.py index dd77890a..c357cd56 100644 --- a/policyengine_us_data/datasets/puf/irs_puf.py +++ b/policyengine_us_data/datasets/puf/irs_puf.py @@ -30,9 +30,7 @@ def generate(self): with pd.HDFStore(self.file_path, mode="w") as storage: storage.put("puf", pd.read_csv(puf_file_path)) - storage.put( - "puf_demographics", pd.read_csv(puf_demographics_file_path) - ) + storage.put("puf_demographics", pd.read_csv(puf_demographics_file_path)) class IRS_PUF_2015(IRS_PUF): diff --git a/policyengine_us_data/datasets/puf/puf.py b/policyengine_us_data/datasets/puf/puf.py index ae8cf4fe..040098c1 100644 --- a/policyengine_us_data/datasets/puf/puf.py +++ b/policyengine_us_data/datasets/puf/puf.py @@ -109,14 +109,10 @@ def simulate_w2_and_ubia_from_puf(puf, *, seed=None, diagnostics=True): ) revenues = np.maximum(qbi, 0) / margins - logit = ( - logit_params["intercept"] + logit_params["slope_per_dollar"] * revenues - ) + logit = logit_params["intercept"] + logit_params["slope_per_dollar"] * revenues # Set p = 0 when simulated receipts == 0 (no revenue means no payroll) - pr_has_employees = np.where( - revenues == 0.0, 0.0, 1.0 / (1.0 + np.exp(-logit)) - ) + pr_has_employees = np.where(revenues == 0.0, 0.0, 1.0 / (1.0 + np.exp(-logit))) has_employees = rng.binomial(1, pr_has_employees) # Labor share simulation @@ -125,8 +121,7 @@ def simulate_w2_and_ubia_from_puf(puf, *, seed=None, diagnostics=True): labor_ratios = np.where( is_rental, rng.beta(rental_beta_a, rental_beta_b, qbi.size) * rental_scale, - rng.beta(non_rental_beta_a, non_rental_beta_b, qbi.size) - * non_rental_scale, + rng.beta(non_rental_beta_a, non_rental_beta_b, qbi.size) * non_rental_scale, ) w2_wages = revenues * labor_ratios * has_employees @@ -209,9 +204,7 @@ def impute_missing_demographics( .fillna(0) ) - puf_with_demographics = puf_with_demographics.sample( - n=10_000, random_state=0 - ) + puf_with_demographics = puf_with_demographics.sample(n=10_000, random_state=0) DEMOGRAPHIC_VARIABLES = [ "AGEDP1", @@ -411,9 +404,7 @@ def preprocess_puf(puf: pd.DataFrame) -> pd.DataFrame: - puf["E25920"].fillna(0) - puf["E25960"].fillna(0) ) != 0 - partnership_se = np.where( - has_partnership, gross_se - schedule_c_f_income, 0 - ) + partnership_se = np.where(has_partnership, gross_se - schedule_c_f_income, 0) puf["partnership_se_income"] = partnership_se # --- Qualified Business Income Deduction (QBID) simulation --- @@ -424,9 +415,9 @@ def preprocess_puf(puf: pd.DataFrame) -> pd.DataFrame: puf_qbi_sources_for_sstb = puf[QBI_PARAMS["sstb_prob_map_by_name"].keys()] largest_qbi_source_name = puf_qbi_sources_for_sstb.idxmax(axis=1) - pr_sstb = largest_qbi_source_name.map( - QBI_PARAMS["sstb_prob_map_by_name"] - ).fillna(0.0) + pr_sstb = largest_qbi_source_name.map(QBI_PARAMS["sstb_prob_map_by_name"]).fillna( + 0.0 + ) puf["business_is_sstb"] = np.random.binomial(n=1, p=pr_sstb) reit_params = QBI_PARAMS["reit_ptp_income_distribution"] @@ -553,9 +544,9 @@ def generate(self): current_index = uprating[uprating.Variable == variable][ self.time_period ].values[0] - start_index = uprating[uprating.Variable == variable][ - 2021 - ].values[0] + start_index = uprating[uprating.Variable == variable][2021].values[ + 0 + ] growth = current_index / start_index arrays[variable] = arrays[variable] * growth self.save_dataset(arrays) @@ -635,9 +626,7 @@ def generate(self): for group in groups_assumed_to_be_tax_unit_like: self.holder[f"{group}_id"] = self.holder["tax_unit_id"] - self.holder[f"person_{group}_id"] = self.holder[ - "person_tax_unit_id" - ] + self.holder[f"person_{group}_id"] = self.holder["person_tax_unit_id"] for key in self.holder: if key == "filing_status": @@ -689,9 +678,7 @@ def add_filer(self, row, tax_unit_id): # Assume all of the interest deduction is the filer's deductible mortgage interest - self.holder["deductible_mortgage_interest"].append( - row["interest_deduction"] - ) + self.holder["deductible_mortgage_interest"].append(row["interest_deduction"]) for key in self.available_financial_vars: if key == "deductible_mortgage_interest": diff --git a/policyengine_us_data/datasets/scf/fed_scf.py b/policyengine_us_data/datasets/scf/fed_scf.py index f67a2c07..8c0d8e8c 100644 --- a/policyengine_us_data/datasets/scf/fed_scf.py +++ b/policyengine_us_data/datasets/scf/fed_scf.py @@ -32,16 +32,12 @@ def load(self): def generate(self): if self._scf_download_url is None: - raise ValueError( - f"No raw SCF data URL known for year {self.time_period}." - ) + raise ValueError(f"No raw SCF data URL known for year {self.time_period}.") url = self._scf_download_url response = requests.get(url, stream=True) - total_size_in_bytes = int( - response.headers.get("content-length", 200e6) - ) + total_size_in_bytes = int(response.headers.get("content-length", 200e6)) progress_bar = tqdm( total=total_size_in_bytes, unit="iB", @@ -49,9 +45,7 @@ def generate(self): desc="Downloading SCF", ) if response.status_code == 404: - raise FileNotFoundError( - "Received a 404 response when fetching the data." - ) + raise FileNotFoundError("Received a 404 response when fetching the data.") with BytesIO() as file: content_length_actual = 0 for data in response.iter_content(int(1e6)): @@ -65,9 +59,7 @@ def generate(self): zipfile = ZipFile(file) with pd.HDFStore(self.file_path, mode="w") as storage: # Find the Stata file, which should be the only .dta file in the zip - dta_files = [ - f for f in zipfile.namelist() if f.endswith(".dta") - ] + dta_files = [f for f in zipfile.namelist() if f.endswith(".dta")] if not dta_files: raise FileNotFoundError( "No .dta file found in the SCF zip archive." diff --git a/policyengine_us_data/datasets/scf/scf.py b/policyengine_us_data/datasets/scf/scf.py index 1567fbbb..3f2f11a7 100644 --- a/policyengine_us_data/datasets/scf/scf.py +++ b/policyengine_us_data/datasets/scf/scf.py @@ -55,9 +55,7 @@ def generate(self): try: scf[key] = np.array(scf[key]) except Exception as e: - print( - f"Warning: Could not convert {key} to numpy array: {e}" - ) + print(f"Warning: Could not convert {key} to numpy array: {e}") self.save_dataset(scf) @@ -110,9 +108,7 @@ def downsample(self, frac: float): # Store original dtypes before modifying original_data: dict = self.load_dataset() - original_dtypes = { - key: original_data[key].dtype for key in original_data - } + original_dtypes = {key: original_data[key].dtype for key in original_data} sim = Microsimulation(dataset=self) sim.subsample(frac=frac) @@ -189,17 +185,13 @@ def rename_columns_to_match_cps(scf: dict, raw_data: pd.DataFrame) -> None: 4: 4, # Asian 5: 7, # Other } - scf["cps_race"] = ( - raw_data["racecl5"].map(race_map).fillna(6).astype(int).values - ) + scf["cps_race"] = raw_data["racecl5"].map(race_map).fillna(6).astype(int).values # Hispanic indicator scf["is_hispanic"] = (raw_data["racecl5"] == 3).values # Children in household if "kids" in raw_data.columns: - scf["own_children_in_household"] = ( - raw_data["kids"].fillna(0).astype(int).values - ) + scf["own_children_in_household"] = raw_data["kids"].fillna(0).astype(int).values # Rent if "rent" in raw_data.columns: @@ -207,9 +199,7 @@ def rename_columns_to_match_cps(scf: dict, raw_data: pd.DataFrame) -> None: # Vehicle loan (auto loan) if "veh_inst" in raw_data.columns: - scf["total_vehicle_installments"] = ( - raw_data["veh_inst"].fillna(0).values - ) + scf["total_vehicle_installments"] = raw_data["veh_inst"].fillna(0).values # Marital status if "married" in raw_data.columns: @@ -269,9 +259,7 @@ def add_auto_loan_interest(scf: dict, year: int) -> None: logger.error( f"Network error downloading SCF data for year {year}: {str(e)}" ) - raise RuntimeError( - f"Failed to download SCF data for year {year}" - ) from e + raise RuntimeError(f"Failed to download SCF data for year {year}") from e # Process zip file try: @@ -282,9 +270,7 @@ def add_auto_loan_interest(scf: dict, year: int) -> None: dta_files = [f for f in z.namelist() if f.endswith(".dta")] if not dta_files: logger.error(f"No Stata files found in zip for year {year}") - raise ValueError( - f"No Stata files found in zip for year {year}" - ) + raise ValueError(f"No Stata files found in zip for year {year}") logger.info(f"Found Stata files: {dta_files}") @@ -298,18 +284,14 @@ def add_auto_loan_interest(scf: dict, year: int) -> None: ) logger.info(f"Read DataFrame with shape {df.shape}") except Exception as e: - logger.error( - f"Error reading Stata file for year {year}: {str(e)}" - ) + logger.error(f"Error reading Stata file for year {year}: {str(e)}") raise RuntimeError( f"Failed to process Stata file for year {year}" ) from e except zipfile.BadZipFile as e: logger.error(f"Bad zip file for year {year}: {str(e)}") - raise RuntimeError( - f"Downloaded zip file is corrupt for year {year}" - ) from e + raise RuntimeError(f"Downloaded zip file is corrupt for year {year}") from e # Process the interest data and add to final SCF dictionary auto_df = df[IDENTIFYER_COLUMNS + AUTO_LOAN_COLUMNS].copy() diff --git a/policyengine_us_data/datasets/sipp/sipp.py b/policyengine_us_data/datasets/sipp/sipp.py index bf8b75dd..d7708266 100644 --- a/policyengine_us_data/datasets/sipp/sipp.py +++ b/policyengine_us_data/datasets/sipp/sipp.py @@ -68,8 +68,7 @@ def train_tip_model(): ) # Sum tip columns (AJB*_TXAMT + TJB*_TXAMT) across all jobs. df["tip_income"] = ( - df[df.columns[df.columns.str.contains("TXAMT")]].fillna(0).sum(axis=1) - * 12 + df[df.columns[df.columns.str.contains("TXAMT")]].fillna(0).sum(axis=1) * 12 ) df["employment_income"] = df.TPTOTINC * 12 df["is_under_18"] = (df.TAGE < 18) & (df.MONTHCODE == 12) diff --git a/policyengine_us_data/db/create_database_tables.py b/policyengine_us_data/db/create_database_tables.py index be22fcbb..d89bad31 100644 --- a/policyengine_us_data/db/create_database_tables.py +++ b/policyengine_us_data/db/create_database_tables.py @@ -39,9 +39,7 @@ class Stratum(SQLModel, table=True): description="Unique identifier for the stratum.", ) definition_hash: str = Field( - sa_column_kwargs={ - "comment": "SHA-256 hash of the stratum's constraints." - }, + sa_column_kwargs={"comment": "SHA-256 hash of the stratum's constraints."}, max_length=64, ) parent_stratum_id: Optional[int] = Field( @@ -89,9 +87,7 @@ class StratumConstraint(SQLModel, table=True): primary_key=True, description="The comparison operator (==, !=, >, >=, <, <=).", ) - value: str = Field( - description="The value for the constraint rule (e.g., '25')." - ) + value: str = Field(description="The value for the constraint rule (e.g., '25').") notes: Optional[str] = Field( default=None, description="Optional notes about the constraint." ) @@ -117,9 +113,7 @@ class Target(SQLModel, table=True): variable: str = Field( description="A variable defined in policyengine-us (e.g., 'income_tax')." ) - period: int = Field( - description="The time period for the data, typically a year." - ) + period: int = Field(description="The time period for the data, typically a year.") stratum_id: int = Field(foreign_key="strata.stratum_id", index=True) reform_id: int = Field( default=0, @@ -156,19 +150,13 @@ def calculate_definition_hash(mapper, connection, target: Stratum): Calculate and set the definition_hash before saving a Stratum instance. """ constraints_history = get_history(target, "constraints_rel") - if not ( - constraints_history.has_changes() or target.definition_hash is None - ): + if not (constraints_history.has_changes() or target.definition_hash is None): return if not target.constraints_rel: # Handle cases with no constraints # Include parent_stratum_id to make hash unique per parent - parent_str = ( - str(target.parent_stratum_id) if target.parent_stratum_id else "" - ) - target.definition_hash = hashlib.sha256( - parent_str.encode("utf-8") - ).hexdigest() + parent_str = str(target.parent_stratum_id) if target.parent_stratum_id else "" + target.definition_hash = hashlib.sha256(parent_str.encode("utf-8")).hexdigest() return constraint_strings = [ @@ -178,9 +166,7 @@ def calculate_definition_hash(mapper, connection, target: Stratum): constraint_strings.sort() # Include parent_stratum_id in the hash to ensure uniqueness per parent - parent_str = ( - str(target.parent_stratum_id) if target.parent_stratum_id else "" - ) + parent_str = str(target.parent_stratum_id) if target.parent_stratum_id else "" fingerprint_text = parent_str + "\n" + "\n".join(constraint_strings) h = hashlib.sha256(fingerprint_text.encode("utf-8")) target.definition_hash = h.hexdigest() @@ -241,10 +227,7 @@ def _validate_geographic_consistency(parent_rows, child_constraints): ) # CD must belong to the parent state. - if ( - "state_fips" in parent_dict - and "congressional_district_geoid" in child_dict - ): + if "state_fips" in parent_dict and "congressional_district_geoid" in child_dict: parent_state = int(parent_dict["state_fips"]) child_cd = int(child_dict["congressional_district_geoid"]) cd_state = child_cd // 100 @@ -288,8 +271,7 @@ def validate_parent_child_constraints(mapper, connection, target: Stratum): return child_set = { - (c.constraint_variable, c.operation, c.value) - for c in target.constraints_rel + (c.constraint_variable, c.operation, c.value) for c in target.constraints_rel } for var, op, val in parent_rows: diff --git a/policyengine_us_data/db/create_initial_strata.py b/policyengine_us_data/db/create_initial_strata.py index aa656c9d..8f6f051c 100644 --- a/policyengine_us_data/db/create_initial_strata.py +++ b/policyengine_us_data/db/create_initial_strata.py @@ -45,16 +45,12 @@ def fetch_congressional_districts(year): df = df[df["district_number"] >= 0].copy() # Filter out statewide summary records for multi-district states - df["n_districts"] = df.groupby("state_fips")["state_fips"].transform( - "count" - ) + df["n_districts"] = df.groupby("state_fips")["state_fips"].transform("count") df = df[(df["n_districts"] == 1) | (df["district_number"] > 0)].copy() df = df.drop(columns=["n_districts"]) df.loc[df["district_number"] == 0, "district_number"] = 1 - df["congressional_district_geoid"] = ( - df["state_fips"] * 100 + df["district_number"] - ) + df["congressional_district_geoid"] = df["state_fips"] * 100 + df["district_number"] df = df[ [ @@ -130,9 +126,7 @@ def main(): # Fetch congressional district data cd_df = fetch_congressional_districts(year) - DATABASE_URL = ( - f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" - ) + DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" engine = create_engine(DATABASE_URL) with Session(engine) as session: @@ -157,9 +151,7 @@ def main(): # Create state-level strata unique_states = cd_df["state_fips"].unique() for state_fips in sorted(unique_states): - state_name = STATE_NAMES.get( - state_fips, f"State FIPS {state_fips}" - ) + state_name = STATE_NAMES.get(state_fips, f"State FIPS {state_fips}") state_stratum = Stratum( parent_stratum_id=us_stratum_id, notes=state_name, diff --git a/policyengine_us_data/db/etl_age.py b/policyengine_us_data/db/etl_age.py index 1a12f372..db5e54da 100644 --- a/policyengine_us_data/db/etl_age.py +++ b/policyengine_us_data/db/etl_age.py @@ -66,9 +66,7 @@ def transform_age_data(age_data, docs): # Filter out Puerto Rico's district and state records # 5001800US7298 = 118th Congress, 5001900US7298 = 119th Congress df_geos = df_data[ - ~df_data["ucgid_str"].isin( - ["5001800US7298", "5001900US7298", "0400000US72"] - ) + ~df_data["ucgid_str"].isin(["5001800US7298", "5001900US7298", "0400000US72"]) ].copy() df = df_geos[["ucgid_str"] + AGE_COLS] @@ -106,9 +104,7 @@ def load_age_data(df_long, geo, year): raise ValueError('geo must be one of "National", "State", "District"') # Prepare to load data ----------- - DATABASE_URL = ( - f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" - ) + DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" engine = create_engine(DATABASE_URL) with Session(engine) as session: diff --git a/policyengine_us_data/db/etl_irs_soi.py b/policyengine_us_data/db/etl_irs_soi.py index aa8122a5..f2b17795 100644 --- a/policyengine_us_data/db/etl_irs_soi.py +++ b/policyengine_us_data/db/etl_irs_soi.py @@ -104,9 +104,7 @@ def make_records( f"WARNING: A59664 values appear to be in thousands (max={max_value:,.0f})" ) print("The IRS may have fixed their data inconsistency.") - print( - "Please verify and remove the special case handling if confirmed." - ) + print("Please verify and remove the special case handling if confirmed.") # Don't apply the fix - data appears to already be in thousands else: # Convert from dollars to thousands to match other columns @@ -162,9 +160,7 @@ def convert_district_data( """Transforms data from pre- to post- 2020 census districts""" df = input_df.copy() old_districts_df = df[df["ucgid_str"].str.startswith("5001800US")].copy() - old_districts_df = old_districts_df.sort_values("ucgid_str").reset_index( - drop=True - ) + old_districts_df = old_districts_df.sort_values("ucgid_str").reset_index(drop=True) old_values = old_districts_df["target_value"].to_numpy() new_values = mapping_matrix.T @ old_values @@ -289,19 +285,15 @@ def transform_soi_data(raw_df): # State ------------------- # You've got agi_stub == 0 in here, which you want to use any time you don't want to # divide data by AGI classes (i.e., agi_stub) - state_df = raw_df.copy().loc[ - (raw_df.STATE != "US") & (raw_df.CONG_DISTRICT == 0) - ] - state_df["ucgid_str"] = "0400000US" + state_df["STATEFIPS"].astype( - str - ).str.zfill(2) + state_df = raw_df.copy().loc[(raw_df.STATE != "US") & (raw_df.CONG_DISTRICT == 0)] + state_df["ucgid_str"] = "0400000US" + state_df["STATEFIPS"].astype(str).str.zfill(2) # District ------------------ district_df = raw_df.copy().loc[(raw_df.CONG_DISTRICT > 0)] - max_cong_district_by_state = raw_df.groupby("STATE")[ - "CONG_DISTRICT" - ].transform("max") + max_cong_district_by_state = raw_df.groupby("STATE")["CONG_DISTRICT"].transform( + "max" + ) district_df = raw_df.copy().loc[ (raw_df["CONG_DISTRICT"] > 0) | (max_cong_district_by_state == 0) ] @@ -370,9 +362,7 @@ def transform_soi_data(raw_df): # Pre- to Post- 2020 Census redisticting mapping = get_district_mapping() converted = [ - convert_district_data( - r, mapping["mapping_matrix"], mapping["new_codes"] - ) + convert_district_data(r, mapping["mapping_matrix"], mapping["new_codes"]) for r in records ] @@ -382,9 +372,7 @@ def transform_soi_data(raw_df): def load_soi_data(long_dfs, year): """Load a list of databases into the db, critically dependent on order""" - DATABASE_URL = ( - f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" - ) + DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" engine = create_engine(DATABASE_URL) session = Session(engine) @@ -458,9 +446,7 @@ def load_soi_data(long_dfs, year): filer_strata["state"][state_fips] = state_filer_stratum.stratum_id # District filer strata - for district_geoid, district_geo_stratum_id in geo_strata[ - "district" - ].items(): + for district_geoid, district_geo_stratum_id in geo_strata["district"].items(): # Check if district filer stratum exists district_filer_stratum = ( session.query(Stratum) @@ -492,9 +478,7 @@ def load_soi_data(long_dfs, year): session.add(district_filer_stratum) session.flush() - filer_strata["district"][ - district_geoid - ] = district_filer_stratum.stratum_id + filer_strata["district"][district_geoid] = district_filer_stratum.stratum_id session.commit() @@ -525,9 +509,7 @@ def load_soi_data(long_dfs, year): ) ] elif geo_info["type"] == "state": - parent_stratum_id = filer_strata["state"][ - geo_info["state_fips"] - ] + parent_stratum_id = filer_strata["state"][geo_info["state_fips"]] note = f"State FIPS {geo_info['state_fips']} EITC received with {n_children} children (filers)" constraints = [ StratumConstraint( @@ -636,9 +618,7 @@ def load_soi_data(long_dfs, year): # Store lookup for later use if geo_info["type"] == "national": - eitc_stratum_lookup["national"][ - n_children - ] = new_stratum.stratum_id + eitc_stratum_lookup["national"][n_children] = new_stratum.stratum_id elif geo_info["type"] == "state": key = (geo_info["state_fips"], n_children) eitc_stratum_lookup["state"][key] = new_stratum.stratum_id @@ -652,8 +632,7 @@ def load_soi_data(long_dfs, year): first_agi_index = [ i for i in range(len(long_dfs)) - if long_dfs[i][["target_variable"]].values[0] - == "adjusted_gross_income" + if long_dfs[i][["target_variable"]].values[0] == "adjusted_gross_income" and long_dfs[i][["breakdown_variable"]].values[0] == "one" ][0] for j in range(8, first_agi_index, 2): @@ -676,17 +655,13 @@ def load_soi_data(long_dfs, year): parent_stratum_id = filer_strata["national"] geo_description = "National" elif geo_info["type"] == "state": - parent_stratum_id = filer_strata["state"][ - geo_info["state_fips"] - ] + parent_stratum_id = filer_strata["state"][geo_info["state_fips"]] geo_description = f"State {geo_info['state_fips']}" elif geo_info["type"] == "district": parent_stratum_id = filer_strata["district"][ geo_info["congressional_district_geoid"] ] - geo_description = ( - f"CD {geo_info['congressional_district_geoid']}" - ) + geo_description = f"CD {geo_info['congressional_district_geoid']}" # Create child stratum with constraint for this IRS variable # Note: This stratum will have the constraint that amount_variable > 0 @@ -741,9 +716,7 @@ def load_soi_data(long_dfs, year): StratumConstraint( constraint_variable="congressional_district_geoid", operation="==", - value=str( - geo_info["congressional_district_geoid"] - ), + value=str(geo_info["congressional_district_geoid"]), ) ) @@ -805,9 +778,7 @@ def load_soi_data(long_dfs, year): elif geo_info["type"] == "district": stratum = session.get( Stratum, - filer_strata["district"][ - geo_info["congressional_district_geoid"] - ], + filer_strata["district"][geo_info["congressional_district_geoid"]], ) # Check if target already exists @@ -822,9 +793,7 @@ def load_soi_data(long_dfs, year): ) if existing_target: - existing_target.value = agi_values.iloc[i][ - ["target_value"] - ].values[0] + existing_target.value = agi_values.iloc[i][["target_value"]].values[0] else: stratum.targets_rel.append( Target( @@ -901,9 +870,7 @@ def load_soi_data(long_dfs, year): person_count = agi_df.iloc[i][["target_value"]].values[0] if geo_info["type"] == "state": - parent_stratum_id = filer_strata["state"][ - geo_info["state_fips"] - ] + parent_stratum_id = filer_strata["state"][geo_info["state_fips"]] note = f"State FIPS {geo_info['state_fips']} filers, AGI >= {agi_income_lower}, AGI < {agi_income_upper}" constraints = [ StratumConstraint( @@ -1000,9 +967,9 @@ def load_soi_data(long_dfs, year): session.flush() if geo_info["type"] == "state": - agi_stratum_lookup["state"][ - geo_info["state_fips"] - ] = new_stratum.stratum_id + agi_stratum_lookup["state"][geo_info["state_fips"]] = ( + new_stratum.stratum_id + ) elif geo_info["type"] == "district": agi_stratum_lookup["district"][ geo_info["congressional_district_geoid"] diff --git a/policyengine_us_data/db/etl_medicaid.py b/policyengine_us_data/db/etl_medicaid.py index dfc19cdc..2c467799 100644 --- a/policyengine_us_data/db/etl_medicaid.py +++ b/policyengine_us_data/db/etl_medicaid.py @@ -116,9 +116,7 @@ def transform_administrative_medicaid_data(state_admin_df, year): ].sort_values("Reporting Period", ascending=False) if not state_history.empty: - fallback_value = state_history.iloc[0][ - "Total Medicaid Enrollment" - ] + fallback_value = state_history.iloc[0]["Total Medicaid Enrollment"] fallback_period = state_history.iloc[0]["Reporting Period"] print( f" {state_abbrev}: Using {fallback_value:,.0f} from period {fallback_period}" @@ -153,9 +151,7 @@ def transform_survey_medicaid_data(cd_survey_df): def load_medicaid_data(long_state, long_cd, year): - DATABASE_URL = ( - f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" - ) + DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" engine = create_engine(DATABASE_URL) with Session(engine) as session: @@ -222,9 +218,7 @@ def load_medicaid_data(long_state, long_cd, year): ) session.add(new_stratum) session.flush() - medicaid_stratum_lookup["state"][ - state_fips - ] = new_stratum.stratum_id + medicaid_stratum_lookup["state"][state_fips] = new_stratum.stratum_id # District ------------------- if long_cd is None: diff --git a/policyengine_us_data/db/etl_national_targets.py b/policyengine_us_data/db/etl_national_targets.py index 2b78b6d6..0e87aa84 100644 --- a/policyengine_us_data/db/etl_national_targets.py +++ b/policyengine_us_data/db/etl_national_targets.py @@ -423,14 +423,10 @@ def transform_national_targets(raw_targets): # Note: income_tax_positive from CBO and eitc from Treasury need # filer constraint cbo_non_tax = [ - t - for t in raw_targets["cbo_targets"] - if t["variable"] != "income_tax_positive" + t for t in raw_targets["cbo_targets"] if t["variable"] != "income_tax_positive" ] cbo_tax = [ - t - for t in raw_targets["cbo_targets"] - if t["variable"] == "income_tax_positive" + t for t in raw_targets["cbo_targets"] if t["variable"] == "income_tax_positive" ] all_direct_targets = raw_targets["direct_sum_targets"] + cbo_non_tax @@ -443,14 +439,10 @@ def transform_national_targets(raw_targets): ) direct_df = ( - pd.DataFrame(all_direct_targets) - if all_direct_targets - else pd.DataFrame() + pd.DataFrame(all_direct_targets) if all_direct_targets else pd.DataFrame() ) tax_filer_df = ( - pd.DataFrame(all_tax_filer_targets) - if all_tax_filer_targets - else pd.DataFrame() + pd.DataFrame(all_tax_filer_targets) if all_tax_filer_targets else pd.DataFrame() ) # Conditional targets stay as list for special processing @@ -459,9 +451,7 @@ def transform_national_targets(raw_targets): return direct_df, tax_filer_df, conditional_targets -def load_national_targets( - direct_targets_df, tax_filer_df, conditional_targets -): +def load_national_targets(direct_targets_df, tax_filer_df, conditional_targets): """ Load national targets into the database. @@ -475,17 +465,13 @@ def load_national_targets( List of conditional count targets requiring strata """ - DATABASE_URL = ( - f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" - ) + DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" engine = create_engine(DATABASE_URL) with Session(engine) as session: # Get the national stratum us_stratum = ( - session.query(Stratum) - .filter(Stratum.parent_stratum_id == None) - .first() + session.query(Stratum).filter(Stratum.parent_stratum_id == None).first() ) if not us_stratum: @@ -511,9 +497,7 @@ def load_national_targets( notes_parts = [] if pd.notna(target_data.get("notes")): notes_parts.append(target_data["notes"]) - notes_parts.append( - f"Source: {target_data.get('source', 'Unknown')}" - ) + notes_parts.append(f"Source: {target_data.get('source', 'Unknown')}") combined_notes = " | ".join(notes_parts) if existing_target: @@ -583,9 +567,7 @@ def load_national_targets( notes_parts = [] if pd.notna(target_data.get("notes")): notes_parts.append(target_data["notes"]) - notes_parts.append( - f"Source: {target_data.get('source', 'Unknown')}" - ) + notes_parts.append(f"Source: {target_data.get('source', 'Unknown')}") combined_notes = " | ".join(notes_parts) if existing_target: @@ -699,23 +681,17 @@ def load_national_targets( ] session.add(new_stratum) - print( - f"Created stratum and target for {constraint_var} enrollment" - ) + print(f"Created stratum and target for {constraint_var} enrollment") session.commit() total_targets = ( - len(direct_targets_df) - + len(tax_filer_df) - + len(conditional_targets) + len(direct_targets_df) + len(tax_filer_df) + len(conditional_targets) ) print(f"\nSuccessfully loaded {total_targets} national targets") print(f" - {len(direct_targets_df)} direct sum targets") print(f" - {len(tax_filer_df)} tax filer targets") - print( - f" - {len(conditional_targets)} enrollment count targets (as strata)" - ) + print(f" - {len(conditional_targets)} enrollment count targets (as strata)") def main(): @@ -730,8 +706,8 @@ def main(): # Transform print("Transforming targets...") - direct_targets_df, tax_filer_df, conditional_targets = ( - transform_national_targets(raw_targets) + direct_targets_df, tax_filer_df, conditional_targets = transform_national_targets( + raw_targets ) # Load diff --git a/policyengine_us_data/db/etl_pregnancy.py b/policyengine_us_data/db/etl_pregnancy.py index c237d262..e8756cfb 100644 --- a/policyengine_us_data/db/etl_pregnancy.py +++ b/policyengine_us_data/db/etl_pregnancy.py @@ -219,9 +219,7 @@ def transform_pregnancy_data( df = births_df.merge(pop_df, on="state_abbrev") df["state_fips"] = df["state_abbrev"].map(STATE_ABBREV_TO_FIPS) # Point-in-time pregnancy count. - df["pregnancy_target"] = ( - df["births"] * PREGNANCY_DURATION_FRACTION - ).round() + df["pregnancy_target"] = (df["births"] * PREGNANCY_DURATION_FRACTION).round() # Rate for stochastic assignment in the CPS build. df["pregnancy_rate"] = ( df["births"] / df["female_15_44"] @@ -268,9 +266,7 @@ def load_pregnancy_data( for _, row in df.iterrows(): state_fips = int(row["state_fips"]) if state_fips not in geo_strata["state"]: - logger.warning( - f"No geographic stratum for FIPS {state_fips}, skipping" - ) + logger.warning(f"No geographic stratum for FIPS {state_fips}, skipping") continue parent_id = geo_strata["state"][state_fips] @@ -362,9 +358,7 @@ def main(): except Exception as e: logger.warning(f"ACS {acs_year} not available: {e}") if pop_df is None: - raise RuntimeError( - f"No ACS population data for {year - 1} or {year - 2}" - ) + raise RuntimeError(f"No ACS population data for {year - 1} or {year - 2}") df = transform_pregnancy_data(births_df, pop_df) diff --git a/policyengine_us_data/db/etl_snap.py b/policyengine_us_data/db/etl_snap.py index 48cb7e77..dc5975a4 100644 --- a/policyengine_us_data/db/etl_snap.py +++ b/policyengine_us_data/db/etl_snap.py @@ -154,9 +154,7 @@ def transform_survey_snap_data(raw_df): def load_administrative_snap_data(df_states, year): - DATABASE_URL = ( - f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" - ) + DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" engine = create_engine(DATABASE_URL) with Session(engine) as session: @@ -244,9 +242,7 @@ def load_survey_snap_data(survey_df, year, snap_stratum_lookup): load_administrative_snap_data, so we don't recreate them. """ - DATABASE_URL = ( - f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" - ) + DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" engine = create_engine(DATABASE_URL) with Session(engine) as session: diff --git a/policyengine_us_data/db/etl_state_income_tax.py b/policyengine_us_data/db/etl_state_income_tax.py index a9ffa35c..95fbc285 100644 --- a/policyengine_us_data/db/etl_state_income_tax.py +++ b/policyengine_us_data/db/etl_state_income_tax.py @@ -320,11 +320,7 @@ def main(): # Print summary total_collections = transformed_df["income_tax_collections"].sum() states_with_tax = len( - [ - s - for s in transformed_df["state_abbrev"] - if s not in NO_INCOME_TAX_STATES - ] + [s for s in transformed_df["state_abbrev"] if s not in NO_INCOME_TAX_STATES] ) logger.info( @@ -337,9 +333,7 @@ def main(): # Print Ohio specifically (for the issue reference) ohio_row = transformed_df[transformed_df["state_abbrev"] == "OH"].iloc[0] - logger.info( - f" Ohio (OH): ${ohio_row['income_tax_collections'] / 1e9:.2f}B" - ) + logger.info(f" Ohio (OH): ${ohio_row['income_tax_collections'] / 1e9:.2f}B") if __name__ == "__main__": diff --git a/policyengine_us_data/db/validate_database.py b/policyengine_us_data/db/validate_database.py index 2fa819f2..b57a83c3 100644 --- a/policyengine_us_data/db/validate_database.py +++ b/policyengine_us_data/db/validate_database.py @@ -9,9 +9,7 @@ import pandas as pd from policyengine_us.system import system -conn = sqlite3.connect( - "policyengine_us_data/storage/calibration/policy_data.db" -) +conn = sqlite3.connect("policyengine_us_data/storage/calibration/policy_data.db") stratum_constraints_df = pd.read_sql("SELECT * FROM stratum_constraints", conn) targets_df = pd.read_sql("SELECT * FROM targets", conn) diff --git a/policyengine_us_data/db/validate_hierarchy.py b/policyengine_us_data/db/validate_hierarchy.py index 353c09ee..1c555703 100644 --- a/policyengine_us_data/db/validate_hierarchy.py +++ b/policyengine_us_data/db/validate_hierarchy.py @@ -31,9 +31,7 @@ def validate_geographic_hierarchy(session): "ERROR: No US-level stratum found (should have parent_stratum_id = None)" ) else: - print( - f"✓ US stratum found: {us_stratum.notes} (ID: {us_stratum.stratum_id})" - ) + print(f"✓ US stratum found: {us_stratum.notes} (ID: {us_stratum.stratum_id})") # Check it has no constraints us_constraints = session.exec( @@ -89,14 +87,10 @@ def validate_geographic_hierarchy(session): c for c in constraints if c.constraint_variable == "state_fips" ] if not state_fips_constraint: - errors.append( - f"ERROR: State '{state.notes}' has no state_fips constraint" - ) + errors.append(f"ERROR: State '{state.notes}' has no state_fips constraint") else: state_ids[state.stratum_id] = state.notes - print( - f" - {state.notes}: state_fips = {state_fips_constraint[0].value}" - ) + print(f" - {state.notes}: state_fips = {state_fips_constraint[0].value}") # Check congressional districts print("\nChecking Congressional Districts...") @@ -112,14 +106,10 @@ def validate_geographic_hierarchy(session): ) ).all() constraint_vars = {c.constraint_variable for c in constraints} - if ( - "congressional_district_geoid" in constraint_vars - and constraint_vars - <= { - "state_fips", - "congressional_district_geoid", - } - ): + if "congressional_district_geoid" in constraint_vars and constraint_vars <= { + "state_fips", + "congressional_district_geoid", + }: all_cds.append(s) print(f"✓ Found {len(all_cds)} congressional/delegate districts") @@ -161,9 +151,7 @@ def validate_geographic_hierarchy(session): wyoming_cds.append(child) if len(wyoming_cds) != 1: - errors.append( - f"ERROR: Wyoming should have 1 CD, found {len(wyoming_cds)}" - ) + errors.append(f"ERROR: Wyoming should have 1 CD, found {len(wyoming_cds)}") else: print(f"✓ Wyoming has correct number of CDs: 1") @@ -187,9 +175,7 @@ def validate_geographic_hierarchy(session): for cd in wrong_parent_cds[:5]: errors.append(f" - {cd.notes}") else: - print( - "✓ No congressional districts incorrectly parented to Wyoming" - ) + print("✓ No congressional districts incorrectly parented to Wyoming") return errors @@ -240,9 +226,7 @@ def validate_demographic_strata(session): if actual == expected_total: print(f"✓ {domain}: {actual} strata") elif actual == 0: - errors.append( - f"ERROR: {domain} has no strata, expected {expected_total}" - ) + errors.append(f"ERROR: {domain} has no strata, expected {expected_total}") else: errors.append( f"WARNING: {domain} has {actual} strata, expected {expected_total}" @@ -322,18 +306,12 @@ def validate_constraint_uniqueness(session): else: hash_counts[stratum.definition_hash] = [stratum] - duplicates = { - h: strata for h, strata in hash_counts.items() if len(strata) > 1 - } + duplicates = {h: strata for h, strata in hash_counts.items() if len(strata) > 1} if duplicates: - errors.append( - f"ERROR: Found {len(duplicates)} duplicate definition_hashes" - ) + errors.append(f"ERROR: Found {len(duplicates)} duplicate definition_hashes") for hash_val, strata in list(duplicates.items())[:3]: # Show first 3 - errors.append( - f" Hash {hash_val[:10]}... appears {len(strata)} times:" - ) + errors.append(f" Hash {hash_val[:10]}... appears {len(strata)} times:") for s in strata[:3]: errors.append(f" - ID {s.stratum_id}: {s.notes[:50]}") else: @@ -345,9 +323,7 @@ def validate_constraint_uniqueness(session): def main(): """Run all validation checks""" - DATABASE_URL = ( - f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" - ) + DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" engine = create_engine(DATABASE_URL) all_errors = [] diff --git a/policyengine_us_data/geography/__init__.py b/policyengine_us_data/geography/__init__.py index 0bcc73f0..f2006819 100644 --- a/policyengine_us_data/geography/__init__.py +++ b/policyengine_us_data/geography/__init__.py @@ -2,9 +2,7 @@ import pandas as pd import os -ZIP_CODE_DATASET_PATH = ( - Path(__file__).parent.parent / "geography" / "zip_codes.csv.gz" -) +ZIP_CODE_DATASET_PATH = Path(__file__).parent.parent / "geography" / "zip_codes.csv.gz" # Avoid circular import error when -us-data is initialized if os.path.exists(ZIP_CODE_DATASET_PATH): diff --git a/policyengine_us_data/geography/county_fips.py b/policyengine_us_data/geography/county_fips.py index 3e5ac518..6bb2b9e9 100644 --- a/policyengine_us_data/geography/county_fips.py +++ b/policyengine_us_data/geography/county_fips.py @@ -21,7 +21,9 @@ def generate_county_fips_2020_dataset(): # COUNTYFP - Three-digit county portion of FIPS (001 for Autauga County, AL, if STATEFP is 01) # COUNTYNAME - County name - COUNTY_FIPS_2020_URL = "https://www2.census.gov/geo/docs/reference/codes2020/national_county2020.txt" + COUNTY_FIPS_2020_URL = ( + "https://www2.census.gov/geo/docs/reference/codes2020/national_county2020.txt" + ) # Download the base tab-delimited data file response = requests.get(COUNTY_FIPS_2020_URL) @@ -68,9 +70,7 @@ def generate_county_fips_2020_dataset(): csv_buffer = BytesIO() # Save CSV into buffer object and reset pointer - county_fips.to_csv( - csv_buffer, index=False, compression="gzip", encoding="utf-8" - ) + county_fips.to_csv(csv_buffer, index=False, compression="gzip", encoding="utf-8") csv_buffer.seek(0) # Upload to Hugging Face diff --git a/policyengine_us_data/geography/create_zip_code_dataset.py b/policyengine_us_data/geography/create_zip_code_dataset.py index eb154cf7..981b5de5 100644 --- a/policyengine_us_data/geography/create_zip_code_dataset.py +++ b/policyengine_us_data/geography/create_zip_code_dataset.py @@ -51,7 +51,5 @@ zcta.set_index("zcta").population[zip_code.zcta].values / zip_code.groupby("zcta").zip_code.count()[zip_code.zcta].values ) -zip_code["county"] = ( - zcta_to_county.set_index("zcta").county[zip_code.zcta].values -) +zip_code["county"] = zcta_to_county.set_index("zcta").county[zip_code.zcta].values zip_code.to_csv("zip_codes.csv", compression="gzip") diff --git a/policyengine_us_data/parameters/__init__.py b/policyengine_us_data/parameters/__init__.py index 2fcddb5a..dc385f8e 100644 --- a/policyengine_us_data/parameters/__init__.py +++ b/policyengine_us_data/parameters/__init__.py @@ -65,8 +65,6 @@ def load_take_up_rate(variable_name: str, year: int = 2018): break if applicable_value is None: - raise ValueError( - f"No take-up rate found for {variable_name} in {year}" - ) + raise ValueError(f"No take-up rate found for {variable_name} in {year}") return applicable_value diff --git a/policyengine_us_data/storage/calibration_targets/audit_county_enum.py b/policyengine_us_data/storage/calibration_targets/audit_county_enum.py index 4849a10e..fcaf443f 100644 --- a/policyengine_us_data/storage/calibration_targets/audit_county_enum.py +++ b/policyengine_us_data/storage/calibration_targets/audit_county_enum.py @@ -109,9 +109,7 @@ def print_categorized_report(invalid_entries, county_to_states): print("\n" + "=" * 60) print("WRONG STATE ASSIGNMENTS") print("=" * 60) - for name, wrong_state, correct_states in sorted( - invalid_entries["wrong_state"] - ): + for name, wrong_state, correct_states in sorted(invalid_entries["wrong_state"]): print(f" {name}") print(f" Listed as: {wrong_state}") print(f" Actually exists in: {', '.join(sorted(correct_states))}") diff --git a/policyengine_us_data/storage/calibration_targets/make_block_cd_distributions.py b/policyengine_us_data/storage/calibration_targets/make_block_cd_distributions.py index 6f55e3f7..f2b634e0 100644 --- a/policyengine_us_data/storage/calibration_targets/make_block_cd_distributions.py +++ b/policyengine_us_data/storage/calibration_targets/make_block_cd_distributions.py @@ -78,9 +78,7 @@ def build_block_cd_distributions(): # Create CD geoid in our format: state_fips * 100 + district # Examples: AL-1 = 101, NY-10 = 3610, DC = 1198 - df["cd_geoid"] = df["state_fips"].astype(int) * 100 + df["CD119"].astype( - int - ) + df["cd_geoid"] = df["state_fips"].astype(int) * 100 + df["CD119"].astype(int) # Step 4: Calculate P(block|CD) print("\nCalculating block probabilities...") @@ -97,9 +95,7 @@ def build_block_cd_distributions(): output = df[["cd_geoid", "GEOID", "probability"]].rename( columns={"GEOID": "block_geoid"} ) - output = output.sort_values( - ["cd_geoid", "probability"], ascending=[True, False] - ) + output = output.sort_values(["cd_geoid", "probability"], ascending=[True, False]) # Step 6: Save as gzipped CSV (parquet requires pyarrow) output_path = STORAGE_FOLDER / "block_cd_distributions.csv.gz" diff --git a/policyengine_us_data/storage/calibration_targets/make_block_crosswalk.py b/policyengine_us_data/storage/calibration_targets/make_block_crosswalk.py index 418e725f..ed0d8cc1 100644 --- a/policyengine_us_data/storage/calibration_targets/make_block_crosswalk.py +++ b/policyengine_us_data/storage/calibration_targets/make_block_crosswalk.py @@ -60,9 +60,7 @@ def download_state_baf(state_fips: str, state_abbr: str) -> dict: ) # Place (City/CDP) - place_file = ( - f"BlockAssign_ST{state_fips}_{state_abbr}_INCPLACE_CDP.txt" - ) + place_file = f"BlockAssign_ST{state_fips}_{state_abbr}_INCPLACE_CDP.txt" if place_file in z.namelist(): df = pd.read_csv(z.open(place_file), sep="|", dtype=str) results["place"] = df.rename( @@ -168,23 +166,17 @@ def build_block_crosswalk(): # Merge other geographies if "sldl" in bafs: - df = df.merge( - bafs["sldl"], on="block_geoid", how="left" - ) + df = df.merge(bafs["sldl"], on="block_geoid", how="left") else: df["sldl"] = None if "place" in bafs: - df = df.merge( - bafs["place"], on="block_geoid", how="left" - ) + df = df.merge(bafs["place"], on="block_geoid", how="left") else: df["place_fips"] = None if "vtd" in bafs: - df = df.merge( - bafs["vtd"], on="block_geoid", how="left" - ) + df = df.merge(bafs["vtd"], on="block_geoid", how="left") else: df["vtd"] = None diff --git a/policyengine_us_data/storage/calibration_targets/make_county_cd_distributions.py b/policyengine_us_data/storage/calibration_targets/make_county_cd_distributions.py index ba68a556..2c91f1ca 100644 --- a/policyengine_us_data/storage/calibration_targets/make_county_cd_distributions.py +++ b/policyengine_us_data/storage/calibration_targets/make_county_cd_distributions.py @@ -126,15 +126,11 @@ def build_county_cd_distributions(): # Create CD geoid in our format: state_fips * 100 + district # Examples: AL-1 = 101, NY-10 = 3610, DC = 1198 - df["cd_geoid"] = df["state_fips"].astype(int) * 100 + df["CD119"].astype( - int - ) + df["cd_geoid"] = df["state_fips"].astype(int) * 100 + df["CD119"].astype(int) # Step 4: Aggregate by (CD, county) print("\nAggregating population by CD and county...") - cd_county_pop = ( - df.groupby(["cd_geoid", "county_fips"])["POP20"].sum().reset_index() - ) + cd_county_pop = df.groupby(["cd_geoid", "county_fips"])["POP20"].sum().reset_index() print(f" Unique CD-county pairs: {len(cd_county_pop):,}") # Step 5: Calculate P(county|CD) @@ -151,9 +147,7 @@ def build_county_cd_distributions(): # Step 6: Map county FIPS to enum names print("\nMapping county FIPS to enum names...") fips_to_enum = build_county_fips_to_enum_mapping() - cd_county_pop["county_name"] = cd_county_pop["county_fips"].map( - fips_to_enum - ) + cd_county_pop["county_name"] = cd_county_pop["county_fips"].map(fips_to_enum) # Check for unmapped counties unmapped = cd_county_pop[cd_county_pop["county_name"].isna()] @@ -177,9 +171,7 @@ def build_county_cd_distributions(): # Step 8: Save CSV output = cd_county_pop[["cd_geoid", "county_name", "probability"]] - output = output.sort_values( - ["cd_geoid", "probability"], ascending=[True, False] - ) + output = output.sort_values(["cd_geoid", "probability"], ascending=[True, False]) output_path = STORAGE_FOLDER / "county_cd_distributions.csv" output.to_csv(output_path, index=False) diff --git a/policyengine_us_data/storage/calibration_targets/make_district_mapping.py b/policyengine_us_data/storage/calibration_targets/make_district_mapping.py index 2b930a2d..bfb4936e 100644 --- a/policyengine_us_data/storage/calibration_targets/make_district_mapping.py +++ b/policyengine_us_data/storage/calibration_targets/make_district_mapping.py @@ -91,9 +91,7 @@ def fetch_block_to_district_map(congress: int) -> pd.DataFrame: return bef[["GEOID", f"CD{congress}"]] else: - raise ValueError( - f"Congress {congress} is not supported by this function." - ) + raise ValueError(f"Congress {congress} is not supported by this function.") def fetch_block_population(state) -> pd.DataFrame: @@ -145,9 +143,7 @@ def fetch_block_population(state) -> pd.DataFrame: geo_df = pd.DataFrame(geo_records, columns=["LOGRECNO", "GEOID"]) # ---------------- P-file: pull total-population cell ---------------------- - p1_records = [ - (p[4], int(p[5])) for p in map(lambda x: x.split("|"), p1_lines) - ] + p1_records = [(p[4], int(p[5])) for p in map(lambda x: x.split("|"), p1_lines)] p1_df = pd.DataFrame(p1_records, columns=["LOGRECNO", "P0010001"]) # ---------------- Merge & finish ----------------------------------------- diff --git a/policyengine_us_data/storage/calibration_targets/pull_hardcoded_targets.py b/policyengine_us_data/storage/calibration_targets/pull_hardcoded_targets.py index da8b5412..3199a56a 100644 --- a/policyengine_us_data/storage/calibration_targets/pull_hardcoded_targets.py +++ b/policyengine_us_data/storage/calibration_targets/pull_hardcoded_targets.py @@ -42,13 +42,9 @@ def pull_hardcoded_targets(): "VARIABLE": list(HARD_CODED_TOTALS.keys()), "VALUE": list(HARD_CODED_TOTALS.values()), "IS_COUNT": [0.0] - * len( - HARD_CODED_TOTALS - ), # All values are monetary amounts, not counts + * len(HARD_CODED_TOTALS), # All values are monetary amounts, not counts "BREAKDOWN_VARIABLE": [np.nan] - * len( - HARD_CODED_TOTALS - ), # No breakdown variable for hardcoded targets + * len(HARD_CODED_TOTALS), # No breakdown variable for hardcoded targets "LOWER_BOUND": [np.nan] * len(HARD_CODED_TOTALS), "UPPER_BOUND": [np.nan] * len(HARD_CODED_TOTALS), } diff --git a/policyengine_us_data/storage/calibration_targets/pull_snap_targets.py b/policyengine_us_data/storage/calibration_targets/pull_snap_targets.py index 1830bdb3..202286e7 100644 --- a/policyengine_us_data/storage/calibration_targets/pull_snap_targets.py +++ b/policyengine_us_data/storage/calibration_targets/pull_snap_targets.py @@ -84,7 +84,9 @@ def extract_usda_snap_data(year=2023): session.headers.update(headers) # Try to visit the main page first to get any necessary cookies - main_page = "https://www.fns.usda.gov/pd/supplemental-nutrition-assistance-program-snap" + main_page = ( + "https://www.fns.usda.gov/pd/supplemental-nutrition-assistance-program-snap" + ) try: session.get(main_page, timeout=30) except: @@ -167,9 +169,7 @@ def extract_usda_snap_data(year=2023): .reset_index(drop=True) ) df_states["GEO_ID"] = "0400000US" + df_states["STATE_FIPS"] - df_states["GEO_NAME"] = "state_" + df_states["State"].map( - STATE_NAME_TO_ABBREV - ) + df_states["GEO_NAME"] = "state_" + df_states["State"].map(STATE_NAME_TO_ABBREV) count_df = df_states[["GEO_ID", "GEO_NAME"]].copy() count_df["VALUE"] = df_states["Households"] diff --git a/policyengine_us_data/storage/calibration_targets/pull_soi_targets.py b/policyengine_us_data/storage/calibration_targets/pull_soi_targets.py index 59050a1b..ce6d9f88 100644 --- a/policyengine_us_data/storage/calibration_targets/pull_soi_targets.py +++ b/policyengine_us_data/storage/calibration_targets/pull_soi_targets.py @@ -129,26 +129,17 @@ def pull_national_soi_variable( national_df: Optional[pd.DataFrame] = None, ) -> pd.DataFrame: """Download and save national AGI totals.""" - df = pd.read_excel( - "https://www.irs.gov/pub/irs-soi/22in54us.xlsx", skiprows=7 - ) + df = pd.read_excel("https://www.irs.gov/pub/irs-soi/22in54us.xlsx", skiprows=7) assert ( - np.abs( - df.iloc[soi_variable_ident, 1] - - df.iloc[soi_variable_ident, 2:12].sum() - ) + np.abs(df.iloc[soi_variable_ident, 1] - df.iloc[soi_variable_ident, 2:12].sum()) < 100 ), "Row 0 doesn't add up — check the file." agi_values = df.iloc[soi_variable_ident, 2:12].astype(int).to_numpy() - agi_values = np.concatenate( - [agi_values[:8], [agi_values[8] + agi_values[9]]] - ) + agi_values = np.concatenate([agi_values[:8], [agi_values[8] + agi_values[9]]]) - agi_brackets = [ - AGI_STUB_TO_BAND[i] for i in range(1, len(SOI_COLUMNS) + 1) - ] + agi_brackets = [AGI_STUB_TO_BAND[i] for i in range(1, len(SOI_COLUMNS) + 1)] result = pd.DataFrame( { @@ -161,9 +152,7 @@ def pull_national_soi_variable( ) # final column order - result = result[ - ["GEO_ID", "GEO_NAME", "LOWER_BOUND", "UPPER_BOUND", "VALUE"] - ] + result = result[["GEO_ID", "GEO_NAME", "LOWER_BOUND", "UPPER_BOUND", "VALUE"]] result["IS_COUNT"] = int(is_count) result["VARIABLE"] = variable_name @@ -186,9 +175,7 @@ def pull_state_soi_variable( state_df: Optional[pd.DataFrame] = None, ) -> pd.DataFrame: """Download and save state AGI totals.""" - df = pd.read_csv( - "https://www.irs.gov/pub/irs-soi/22in55cmcsv.csv", thousands="," - ) + df = pd.read_csv("https://www.irs.gov/pub/irs-soi/22in55cmcsv.csv", thousands=",") merged = ( df[df["AGI_STUB"].isin([9, 10])] @@ -211,17 +198,11 @@ def pull_state_soi_variable( ["GEO_ID", "GEO_NAME", "agi_bracket", soi_variable_ident], ].rename(columns={soi_variable_ident: "VALUE"}) - result["LOWER_BOUND"] = result["agi_bracket"].map( - lambda b: AGI_BOUNDS[b][0] - ) - result["UPPER_BOUND"] = result["agi_bracket"].map( - lambda b: AGI_BOUNDS[b][1] - ) + result["LOWER_BOUND"] = result["agi_bracket"].map(lambda b: AGI_BOUNDS[b][0]) + result["UPPER_BOUND"] = result["agi_bracket"].map(lambda b: AGI_BOUNDS[b][1]) # final column order - result = result[ - ["GEO_ID", "GEO_NAME", "LOWER_BOUND", "UPPER_BOUND", "VALUE"] - ] + result = result[["GEO_ID", "GEO_NAME", "LOWER_BOUND", "UPPER_BOUND", "VALUE"]] result["IS_COUNT"] = int(is_count) result["VARIABLE"] = variable_name @@ -249,9 +230,7 @@ def pull_district_soi_variable( df = df[df["agi_stub"] != 0] df["STATEFIPS"] = df["STATEFIPS"].astype(int).astype(str).str.zfill(2) - df["CONG_DISTRICT"] = ( - df["CONG_DISTRICT"].astype(int).astype(str).str.zfill(2) - ) + df["CONG_DISTRICT"] = df["CONG_DISTRICT"].astype(int).astype(str).str.zfill(2) if SOI_DISTRICT_TAX_YEAR >= 2024: raise RuntimeError( f"SOI tax year {SOI_DISTRICT_TAX_YEAR} may need " @@ -288,12 +267,8 @@ def pull_district_soi_variable( ] ].rename(columns={soi_variable_ident: "VALUE"}) - result["LOWER_BOUND"] = result["agi_bracket"].map( - lambda b: AGI_BOUNDS[b][0] - ) - result["UPPER_BOUND"] = result["agi_bracket"].map( - lambda b: AGI_BOUNDS[b][1] - ) + result["LOWER_BOUND"] = result["agi_bracket"].map(lambda b: AGI_BOUNDS[b][0]) + result["UPPER_BOUND"] = result["agi_bracket"].map(lambda b: AGI_BOUNDS[b][1]) # if redistrict: # result = apply_redistricting(result, variable_name) @@ -308,25 +283,23 @@ def pull_district_soi_variable( # Check that all GEO_IDs are valid produced_codes = set(result["GEO_ID"]) invalid_codes = produced_codes - valid_district_codes - assert ( - not invalid_codes - ), f"Invalid district codes after redistricting: {invalid_codes}" + assert not invalid_codes, ( + f"Invalid district codes after redistricting: {invalid_codes}" + ) # Check we have exactly 436 districts - assert ( - len(produced_codes) == 436 - ), f"Expected 436 districts after redistricting, got {len(produced_codes)}" + assert len(produced_codes) == 436, ( + f"Expected 436 districts after redistricting, got {len(produced_codes)}" + ) # Check that all GEO_IDs successfully mapped to names missing_names = result[result["GEO_NAME"].isna()]["GEO_ID"].unique() - assert ( - len(missing_names) == 0 - ), f"GEO_IDs without names in ID_TO_NAME mapping: {missing_names}" + assert len(missing_names) == 0, ( + f"GEO_IDs without names in ID_TO_NAME mapping: {missing_names}" + ) # final column order - result = result[ - ["GEO_ID", "GEO_NAME", "LOWER_BOUND", "UPPER_BOUND", "VALUE"] - ] + result = result[["GEO_ID", "GEO_NAME", "LOWER_BOUND", "UPPER_BOUND", "VALUE"]] result["IS_COUNT"] = int(is_count) result["VARIABLE"] = variable_name @@ -457,15 +430,11 @@ def combine_geography_levels(districts: Optional[bool] = False) -> None: ) # Get state totals indexed by STATEFIPS - state_totals = state.loc[state_mask].set_index("STATEFIPS")[ - "VALUE" - ] + state_totals = state.loc[state_mask].set_index("STATEFIPS")["VALUE"] # Get district totals grouped by STATEFIPS district_totals = ( - district.loc[district_mask] - .groupby("STATEFIPS")["VALUE"] - .sum() + district.loc[district_mask].groupby("STATEFIPS")["VALUE"].sum() ) # Check and rescale districts for each state @@ -480,12 +449,8 @@ def combine_geography_levels(districts: Optional[bool] = False) -> None: f"Districts' sum does not match {fips} state total for {variable}/{count_type} " f"in bracket [{lower}, {upper}]. Rescaling district targets." ) - rescale_mask = district_mask & ( - district["STATEFIPS"] == fips - ) - district.loc[rescale_mask, "VALUE"] *= ( - s_total / d_total - ) + rescale_mask = district_mask & (district["STATEFIPS"] == fips) + district.loc[rescale_mask, "VALUE"] *= s_total / d_total # Combine all data combined = pd.concat( diff --git a/policyengine_us_data/storage/upload_completed_datasets.py b/policyengine_us_data/storage/upload_completed_datasets.py index d4f7a070..7af0da04 100644 --- a/policyengine_us_data/storage/upload_completed_datasets.py +++ b/policyengine_us_data/storage/upload_completed_datasets.py @@ -103,9 +103,7 @@ def _check_group_has_data(f, name): ) # At least one income group must have data - has_income = any( - _check_group_has_data(f, g) for g in INCOME_GROUPS - ) + has_income = any(_check_group_has_data(f, g) for g in INCOME_GROUPS) if not has_income: errors.append( f"No income data found. Need at least one of " @@ -127,9 +125,7 @@ def _check_group_has_data(f, name): try: dataset_cls = FILENAME_TO_DATASET.get(filename) if dataset_cls is None: - raise DatasetValidationError( - f"No dataset class registered for {filename}" - ) + raise DatasetValidationError(f"No dataset class registered for {filename}") sim = Microsimulation(dataset=dataset_cls) year = 2024 diff --git a/policyengine_us_data/tests/test_calibration/conftest.py b/policyengine_us_data/tests/test_calibration/conftest.py index 35449156..0698cef0 100644 --- a/policyengine_us_data/tests/test_calibration/conftest.py +++ b/policyengine_us_data/tests/test_calibration/conftest.py @@ -13,6 +13,4 @@ def db_uri(): @pytest.fixture(scope="module") def dataset_path(): - return str( - STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5" - ) + return str(STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5") diff --git a/policyengine_us_data/tests/test_calibration/create_test_fixture.py b/policyengine_us_data/tests/test_calibration/create_test_fixture.py index 00334734..2fadeeeb 100644 --- a/policyengine_us_data/tests/test_calibration/create_test_fixture.py +++ b/policyengine_us_data/tests/test_calibration/create_test_fixture.py @@ -30,9 +30,7 @@ def create_test_fixture(): # Household-level arrays household_ids = np.arange(N_HOUSEHOLDS, dtype=np.int32) - household_weights = np.random.uniform(500, 3000, N_HOUSEHOLDS).astype( - np.float32 - ) + household_weights = np.random.uniform(500, 3000, N_HOUSEHOLDS).astype(np.float32) # Assign households to states (use NC=37 and AK=2 for testing) # 40 households in NC, 10 in AK @@ -102,18 +100,14 @@ def create_test_fixture(): f["household_id"].create_dataset(TIME_PERIOD, data=household_ids) f.create_group("household_weight") - f["household_weight"].create_dataset( - TIME_PERIOD, data=household_weights - ) + f["household_weight"].create_dataset(TIME_PERIOD, data=household_weights) # Person variables f.create_group("person_id") f["person_id"].create_dataset(TIME_PERIOD, data=person_ids) f.create_group("person_household_id") - f["person_household_id"].create_dataset( - TIME_PERIOD, data=person_household_ids - ) + f["person_household_id"].create_dataset(TIME_PERIOD, data=person_household_ids) f.create_group("person_weight") f["person_weight"].create_dataset(TIME_PERIOD, data=person_weights) @@ -122,18 +116,14 @@ def create_test_fixture(): f["age"].create_dataset(TIME_PERIOD, data=ages) f.create_group("employment_income") - f["employment_income"].create_dataset( - TIME_PERIOD, data=employment_income - ) + f["employment_income"].create_dataset(TIME_PERIOD, data=employment_income) # Tax unit f.create_group("tax_unit_id") f["tax_unit_id"].create_dataset(TIME_PERIOD, data=tax_unit_ids) f.create_group("person_tax_unit_id") - f["person_tax_unit_id"].create_dataset( - TIME_PERIOD, data=person_tax_unit_ids - ) + f["person_tax_unit_id"].create_dataset(TIME_PERIOD, data=person_tax_unit_ids) f.create_group("tax_unit_weight") f["tax_unit_weight"].create_dataset(TIME_PERIOD, data=tax_unit_weights) @@ -143,9 +133,7 @@ def create_test_fixture(): f["spm_unit_id"].create_dataset(TIME_PERIOD, data=spm_unit_ids) f.create_group("person_spm_unit_id") - f["person_spm_unit_id"].create_dataset( - TIME_PERIOD, data=person_spm_unit_ids - ) + f["person_spm_unit_id"].create_dataset(TIME_PERIOD, data=person_spm_unit_ids) f.create_group("spm_unit_weight") f["spm_unit_weight"].create_dataset(TIME_PERIOD, data=spm_unit_weights) @@ -155,9 +143,7 @@ def create_test_fixture(): f["family_id"].create_dataset(TIME_PERIOD, data=family_ids) f.create_group("person_family_id") - f["person_family_id"].create_dataset( - TIME_PERIOD, data=person_family_ids - ) + f["person_family_id"].create_dataset(TIME_PERIOD, data=person_family_ids) f.create_group("family_weight") f["family_weight"].create_dataset(TIME_PERIOD, data=family_weights) @@ -172,9 +158,7 @@ def create_test_fixture(): ) f.create_group("marital_unit_weight") - f["marital_unit_weight"].create_dataset( - TIME_PERIOD, data=marital_unit_weights - ) + f["marital_unit_weight"].create_dataset(TIME_PERIOD, data=marital_unit_weights) # Geography (household level) f.create_group("state_fips") diff --git a/policyengine_us_data/tests/test_calibration/test_build_matrix_masking.py b/policyengine_us_data/tests/test_calibration/test_build_matrix_masking.py index 122be1fb..81cd925d 100644 --- a/policyengine_us_data/tests/test_calibration/test_build_matrix_masking.py +++ b/policyengine_us_data/tests/test_calibration/test_build_matrix_masking.py @@ -15,9 +15,7 @@ from policyengine_us_data.storage import STORAGE_FOLDER -DATASET_PATH = str( - STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5" -) +DATASET_PATH = str(STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5") DB_PATH = str(STORAGE_FOLDER / "calibration" / "policy_data.db") DB_URI = f"sqlite:///{DB_PATH}" @@ -44,9 +42,7 @@ def matrix_result(): sim = Microsimulation(dataset=DATASET_PATH) n_records = sim.calculate("household_id").values.shape[0] - geography = assign_random_geography( - n_records, n_clones=N_CLONES, seed=SEED - ) + geography = assign_random_geography(n_records, n_clones=N_CLONES, seed=SEED) builder = UnifiedMatrixBuilder( db_uri=DB_URI, time_period=2024, @@ -58,9 +54,7 @@ def matrix_result(): target_filter={"domain_variables": ["snap", "medicaid"]}, ) X_csc = X_sparse.tocsc() - national_rows = targets_df[ - targets_df["geo_level"] == "national" - ].index.values + national_rows = targets_df[targets_df["geo_level"] == "national"].index.values district_targets = targets_df[targets_df["geo_level"] == "district"] record_idx = None for ri in range(n_records): @@ -186,11 +180,7 @@ def test_clone_visible_only_to_own_cd(self, matrix_result): vals_0 = X_csc[:, col_0].toarray().ravel() same_state_other_cd = district_targets[ - ( - district_targets["geographic_id"].apply( - lambda g: g.startswith(state_0) - ) - ) + (district_targets["geographic_id"].apply(lambda g: g.startswith(state_0))) & (district_targets["geographic_id"] != cd_0) ] @@ -220,9 +210,7 @@ def test_clone_nonzero_for_own_cd(self, matrix_result): X_csc = X.tocsc() vals_0 = X_csc[:, col_0].toarray().ravel() - any_nonzero = any( - vals_0[row.name] != 0 for _, row in own_cd_targets.iterrows() + any_nonzero = any(vals_0[row.name] != 0 for _, row in own_cd_targets.iterrows()) + assert any_nonzero, ( + f"Clone 0 should have at least one non-zero entry for its own CD {cd_0}" ) - assert ( - any_nonzero - ), f"Clone 0 should have at least one non-zero entry for its own CD {cd_0}" diff --git a/policyengine_us_data/tests/test_calibration/test_clone_and_assign.py b/policyengine_us_data/tests/test_calibration/test_clone_and_assign.py index 93bc5473..9eb1b6f5 100644 --- a/policyengine_us_data/tests/test_calibration/test_clone_and_assign.py +++ b/policyengine_us_data/tests/test_calibration/test_clone_and_assign.py @@ -69,9 +69,7 @@ def test_loads_and_normalizes(self, tmp_path): "policyengine_us_data.calibration.clone_and_assign.STORAGE_FOLDER", tmp_path, ): - blocks, cds, states, probs = ( - load_global_block_distribution.__wrapped__() - ) + blocks, cds, states, probs = load_global_block_distribution.__wrapped__() assert len(blocks) == 9 np.testing.assert_almost_equal(probs.sum(), 1.0) @@ -140,12 +138,11 @@ def test_no_cd_collisions_across_clones(self, mock_load): r = assign_random_geography(n_records=100, n_clones=3, seed=42) for rec in range(r.n_records): rec_cds = [ - r.cd_geoid[clone * r.n_records + rec] - for clone in range(r.n_clones) + r.cd_geoid[clone * r.n_records + rec] for clone in range(r.n_clones) ] - assert len(rec_cds) == len( - set(rec_cds) - ), f"Record {rec} has duplicate CDs: {rec_cds}" + assert len(rec_cds) == len(set(rec_cds)), ( + f"Record {rec} has duplicate CDs: {rec_cds}" + ) def test_missing_file_raises(self, tmp_path): fake = tmp_path / "nonexistent" diff --git a/policyengine_us_data/tests/test_calibration/test_county_assignment.py b/policyengine_us_data/tests/test_calibration/test_county_assignment.py index 03d7342d..d9b64991 100644 --- a/policyengine_us_data/tests/test_calibration/test_county_assignment.py +++ b/policyengine_us_data/tests/test_calibration/test_county_assignment.py @@ -47,9 +47,7 @@ def test_ny_cd_gets_ny_counties(self): for idx in result: county_name = County._member_names_[idx] # Should end with _NY - assert county_name.endswith( - "_NY" - ), f"Got non-NY county: {county_name}" + assert county_name.endswith("_NY"), f"Got non-NY county: {county_name}" def test_ca_cd_gets_ca_counties(self): """Verify CA CDs get CA counties.""" @@ -58,9 +56,7 @@ def test_ca_cd_gets_ca_counties(self): for idx in result: county_name = County._member_names_[idx] - assert county_name.endswith( - "_CA" - ), f"Got non-CA county: {county_name}" + assert county_name.endswith("_CA"), f"Got non-CA county: {county_name}" class TestCountyIndex: diff --git a/policyengine_us_data/tests/test_calibration/test_puf_impute.py b/policyengine_us_data/tests/test_calibration/test_puf_impute.py index 1bce3cf7..d803486e 100644 --- a/policyengine_us_data/tests/test_calibration/test_puf_impute.py +++ b/policyengine_us_data/tests/test_calibration/test_puf_impute.py @@ -150,9 +150,7 @@ def test_reduces_to_target(self): rng.uniform(500_000, 5_000_000, size=250), ] ) - idx = _stratified_subsample_index( - income, target_n=10_000, top_pct=99.5 - ) + idx = _stratified_subsample_index(income, target_n=10_000, top_pct=99.5) assert len(idx) == 10_000 def test_preserves_top_earners(self): @@ -166,9 +164,7 @@ def test_preserves_top_earners(self): threshold = np.percentile(income, 99.5) n_top = (income >= threshold).sum() - idx = _stratified_subsample_index( - income, target_n=10_000, top_pct=99.5 - ) + idx = _stratified_subsample_index(income, target_n=10_000, top_pct=99.5) selected_income = income[idx] n_top_selected = (selected_income >= threshold).sum() assert n_top_selected == n_top diff --git a/policyengine_us_data/tests/test_calibration/test_retirement_imputation.py b/policyengine_us_data/tests/test_calibration/test_retirement_imputation.py index d8740d16..5b635c79 100644 --- a/policyengine_us_data/tests/test_calibration/test_retirement_imputation.py +++ b/policyengine_us_data/tests/test_calibration/test_retirement_imputation.py @@ -54,14 +54,8 @@ def _make_mock_data(n_persons=20, n_households=5, time_period=2024): "person_household_id": {time_period: hh_ids_person}, "person_tax_unit_id": {time_period: hh_ids_person.copy()}, "person_spm_unit_id": {time_period: hh_ids_person.copy()}, - "age": { - time_period: rng.integers(18, 80, size=n_persons).astype( - np.float32 - ) - }, - "is_male": { - time_period: rng.integers(0, 2, size=n_persons).astype(np.float32) - }, + "age": {time_period: rng.integers(18, 80, size=n_persons).astype(np.float32)}, + "is_male": {time_period: rng.integers(0, 2, size=n_persons).astype(np.float32)}, "household_weight": {time_period: np.ones(n_households) * 1000}, "employment_income": { time_period: rng.uniform(0, 100_000, n_persons).astype(np.float32) @@ -71,9 +65,7 @@ def _make_mock_data(n_persons=20, n_households=5, time_period=2024): }, } for var in CPS_RETIREMENT_VARIABLES: - data[var] = { - time_period: rng.uniform(0, 5000, n_persons).astype(np.float32) - } + data[var] = {time_period: rng.uniform(0, 5000, n_persons).astype(np.float32)} return data @@ -137,9 +129,9 @@ class TestConstants: def test_retirement_vars_not_in_imputed(self): """Retirement vars must NOT be in IMPUTED_VARIABLES.""" for var in CPS_RETIREMENT_VARIABLES: - assert ( - var not in IMPUTED_VARIABLES - ), f"{var} should not be in IMPUTED_VARIABLES" + assert var not in IMPUTED_VARIABLES, ( + f"{var} should not be in IMPUTED_VARIABLES" + ) def test_retirement_vars_not_in_overridden(self): for var in CPS_RETIREMENT_VARIABLES: @@ -169,14 +161,12 @@ def test_retirement_predictors_include_demographics(self): def test_income_predictors_in_imputed_variables(self): """All income predictors must be available from PUF QRF.""" for var in RETIREMENT_INCOME_PREDICTORS: - assert ( - var in IMPUTED_VARIABLES - ), f"{var} not in IMPUTED_VARIABLES — won't be in puf_imputations" + assert var in IMPUTED_VARIABLES, ( + f"{var} not in IMPUTED_VARIABLES — won't be in puf_imputations" + ) def test_predictors_are_combined_lists(self): - expected = ( - RETIREMENT_DEMOGRAPHIC_PREDICTORS + RETIREMENT_INCOME_PREDICTORS - ) + expected = RETIREMENT_DEMOGRAPHIC_PREDICTORS + RETIREMENT_INCOME_PREDICTORS assert RETIREMENT_PREDICTORS == expected @@ -268,18 +258,12 @@ def _setup(self): self.puf_imputations = { "employment_income": emp, "self_employment_income": se, - "taxable_interest_income": rng.uniform(0, 5_000, self.n).astype( - np.float32 - ), + "taxable_interest_income": rng.uniform(0, 5_000, self.n).astype(np.float32), "qualified_dividend_income": rng.uniform(0, 3_000, self.n).astype( np.float32 ), - "taxable_pension_income": rng.uniform(0, 20_000, self.n).astype( - np.float32 - ), - "social_security": rng.uniform(0, 15_000, self.n).astype( - np.float32 - ), + "taxable_pension_income": rng.uniform(0, 20_000, self.n).astype(np.float32), + "social_security": rng.uniform(0, 15_000, self.n).astype(np.float32), } self.cps_df = _make_cps_df(self.n, rng) @@ -317,10 +301,7 @@ def _uniform_preds(self, value): def _random_preds(self, low, high, seed=99): rng = np.random.default_rng(seed) return pd.DataFrame( - { - var: rng.uniform(low, high, self.n) - for var in CPS_RETIREMENT_VARIABLES - } + {var: rng.uniform(low, high, self.n) for var in CPS_RETIREMENT_VARIABLES} ) def test_returns_all_retirement_vars(self): @@ -365,27 +346,23 @@ def test_401k_zero_when_no_wages(self): "traditional_401k_contributions", "roth_401k_contributions", ): - assert np.all( - result[var][zero_wage] == 0 - ), f"{var} should be 0 when employment_income is 0" + assert np.all(result[var][zero_wage] == 0), ( + f"{var} should be 0 when employment_income is 0" + ) def test_se_pension_zero_when_no_se_income(self): result = self._call_with_mocks(self._uniform_preds(5_000.0)) zero_se = self.puf_imputations["self_employment_income"] == 0 assert zero_se.sum() == 20 - assert np.all( - result["self_employed_pension_contributions"][zero_se] == 0 - ) + assert np.all(result["self_employed_pension_contributions"][zero_se] == 0) def test_catch_up_age_threshold(self): """Records age >= 50 get higher caps than younger.""" - self.cps_df["age"] = np.concatenate( - [np.full(25, 30.0), np.full(25, 55.0)] - ) + self.cps_df["age"] = np.concatenate([np.full(25, 30.0), np.full(25, 55.0)]) # All have positive income - self.puf_imputations["employment_income"] = np.full( - self.n, 100_000.0 - ).astype(np.float32) + self.puf_imputations["employment_income"] = np.full(self.n, 100_000.0).astype( + np.float32 + ) lim = _get_retirement_limits(self.time_period) val = float(lim["401k"]) + 1000 # 24000 @@ -402,9 +379,7 @@ def test_catch_up_age_threshold(self): def test_ira_catch_up_threshold(self): """IRA catch-up also works for age >= 50.""" - self.cps_df["age"] = np.concatenate( - [np.full(25, 30.0), np.full(25, 55.0)] - ) + self.cps_df["age"] = np.concatenate([np.full(25, 30.0), np.full(25, 55.0)]) lim = _get_retirement_limits(self.time_period) val = float(lim["ira"]) + 500 # 7500 @@ -430,9 +405,7 @@ def test_401k_nonzero_for_positive_wages(self): def test_se_pension_nonzero_for_positive_se(self): result = self._call_with_mocks(self._uniform_preds(5_000.0)) pos_se = self.puf_imputations["self_employment_income"] > 0 - assert np.all( - result["self_employed_pension_contributions"][pos_se] > 0 - ) + assert np.all(result["self_employed_pension_contributions"][pos_se] > 0) def test_se_pension_capped_at_rate_times_income(self): """SE pension should not exceed 25% of SE income.""" @@ -458,9 +431,7 @@ def test_qrf_failure_returns_zeros(self): # Make a QRF that crashes on fit_predict mock_qrf_cls = MagicMock() - mock_qrf_cls.return_value.fit_predict.side_effect = RuntimeError( - "QRF exploded" - ) + mock_qrf_cls.return_value.fit_predict.side_effect = RuntimeError("QRF exploded") qrf_mod = sys.modules["microimpute.models.qrf"] old_qrf = getattr(qrf_mod, "QRF", None) @@ -486,9 +457,7 @@ def test_training_data_failure_returns_zeros(self): import sys mock_sim = MagicMock() - mock_sim.calculate_dataframe.side_effect = ValueError( - "missing variable" - ) + mock_sim.calculate_dataframe.side_effect = ValueError("missing variable") qrf_mod = sys.modules["microimpute.models.qrf"] old_qrf = getattr(qrf_mod, "QRF", None) @@ -538,9 +507,7 @@ def test_retirement_vars_use_imputed_when_available(self): state_fips = np.array([1, 2, 36, 6, 48]) n = 20 - fake_retirement = { - var: np.full(n, 999.0) for var in CPS_RETIREMENT_VARIABLES - } + fake_retirement = {var: np.full(n, 999.0) for var in CPS_RETIREMENT_VARIABLES} with ( patch( @@ -581,12 +548,8 @@ def test_cps_half_unchanged_with_imputation(self): state_fips = np.array([1, 2, 36, 6, 48]) n = 20 - originals = { - var: data[var][2024].copy() for var in CPS_RETIREMENT_VARIABLES - } - fake_retirement = { - var: np.zeros(n) for var in CPS_RETIREMENT_VARIABLES - } + originals = {var: data[var][2024].copy() for var in CPS_RETIREMENT_VARIABLES} + fake_retirement = {var: np.zeros(n) for var in CPS_RETIREMENT_VARIABLES} with ( patch( @@ -617,9 +580,7 @@ def test_cps_half_unchanged_with_imputation(self): ) for var in CPS_RETIREMENT_VARIABLES: - np.testing.assert_array_equal( - result[var][2024][:n], originals[var] - ) + np.testing.assert_array_equal(result[var][2024][:n], originals[var]) def test_puf_half_gets_zero_retirement_for_zero_imputed(self): """When imputation returns zeros, PUF half should be zero.""" @@ -627,9 +588,7 @@ def test_puf_half_gets_zero_retirement_for_zero_imputed(self): state_fips = np.array([1, 2, 36, 6, 48]) n = 20 - fake_retirement = { - var: np.zeros(n) for var in CPS_RETIREMENT_VARIABLES - } + fake_retirement = {var: np.zeros(n) for var in CPS_RETIREMENT_VARIABLES} with ( patch( @@ -699,6 +658,6 @@ def test_401k_ira_from_policyengine_us(self): ours = _get_retirement_limits(year) pe = pe_limits(year) for key in ["401k", "401k_catch_up", "ira", "ira_catch_up"]: - assert ( - ours[key] == pe[key] - ), f"Year {year} key {key}: {ours[key]} != {pe[key]}" + assert ours[key] == pe[key], ( + f"Year {year} key {key}: {ours[key]} != {pe[key]}" + ) diff --git a/policyengine_us_data/tests/test_calibration/test_source_impute.py b/policyengine_us_data/tests/test_calibration/test_source_impute.py index c69ec653..517a559e 100644 --- a/policyengine_us_data/tests/test_calibration/test_source_impute.py +++ b/policyengine_us_data/tests/test_calibration/test_source_impute.py @@ -71,9 +71,7 @@ def test_scf_variables_defined(self): def test_all_source_variables_defined(self): expected = ( - ACS_IMPUTED_VARIABLES - + SIPP_IMPUTED_VARIABLES - + SCF_IMPUTED_VARIABLES + ACS_IMPUTED_VARIABLES + SIPP_IMPUTED_VARIABLES + SCF_IMPUTED_VARIABLES ) assert ALL_SOURCE_VARIABLES == expected diff --git a/policyengine_us_data/tests/test_calibration/test_stacked_dataset_builder.py b/policyengine_us_data/tests/test_calibration/test_stacked_dataset_builder.py index b4f4831d..339dec4e 100644 --- a/policyengine_us_data/tests/test_calibration/test_stacked_dataset_builder.py +++ b/policyengine_us_data/tests/test_calibration/test_stacked_dataset_builder.py @@ -40,9 +40,7 @@ def _make_geography(n_hh, cds): ], dtype="U15", ) - state_fips_arr = np.array( - [int(cd) // 100 for cd in cd_geoid], dtype=np.int32 - ) + state_fips_arr = np.array([int(cd) // 100 for cd in cd_geoid], dtype=np.int32) county_fips = np.array([b[:5] for b in block_geoid], dtype="U5") return GeographyAssignment( block_geoid=block_geoid, @@ -154,17 +152,15 @@ def test_counties_match_state(self, stacked_result): state_fips = row["state_fips"] if state_fips == 37: - assert county.endswith( - "_NC" - ), f"NC county should end with _NC: {county}" + assert county.endswith("_NC"), ( + f"NC county should end with _NC: {county}" + ) elif state_fips == 2: - assert county.endswith( - "_AK" - ), f"AK county should end with _AK: {county}" + assert county.endswith("_AK"), ( + f"AK county should end with _AK: {county}" + ) - def test_household_count_matches_weights( - self, stacked_result, test_weights - ): + def test_household_count_matches_weights(self, stacked_result, test_weights): """Number of output households should match non-zero weights.""" hh_df = stacked_result["hh_df"] expected_households = (test_weights > 0).sum() @@ -222,40 +218,30 @@ class TestEntityReindexing: def test_family_ids_are_unique(self, stacked_sim): """Family IDs should be globally unique across all CDs.""" family_ids = stacked_sim.calculate("family_id", map_to="family").values - assert len(family_ids) == len( - set(family_ids) - ), "Family IDs should be unique" + assert len(family_ids) == len(set(family_ids)), "Family IDs should be unique" def test_tax_unit_ids_are_unique(self, stacked_sim): """Tax unit IDs should be globally unique.""" - tax_unit_ids = stacked_sim.calculate( - "tax_unit_id", map_to="tax_unit" - ).values - assert len(tax_unit_ids) == len( - set(tax_unit_ids) - ), "Tax unit IDs should be unique" + tax_unit_ids = stacked_sim.calculate("tax_unit_id", map_to="tax_unit").values + assert len(tax_unit_ids) == len(set(tax_unit_ids)), ( + "Tax unit IDs should be unique" + ) def test_spm_unit_ids_are_unique(self, stacked_sim): """SPM unit IDs should be globally unique.""" - spm_unit_ids = stacked_sim.calculate( - "spm_unit_id", map_to="spm_unit" - ).values - assert len(spm_unit_ids) == len( - set(spm_unit_ids) - ), "SPM unit IDs should be unique" + spm_unit_ids = stacked_sim.calculate("spm_unit_id", map_to="spm_unit").values + assert len(spm_unit_ids) == len(set(spm_unit_ids)), ( + "SPM unit IDs should be unique" + ) def test_person_family_id_matches_family_id(self, stacked_sim): """person_family_id should reference valid family_ids.""" person_family_ids = stacked_sim.calculate( "person_family_id", map_to="person" ).values - family_ids = set( - stacked_sim.calculate("family_id", map_to="family").values - ) + family_ids = set(stacked_sim.calculate("family_id", map_to="family").values) for pf_id in person_family_ids: - assert ( - pf_id in family_ids - ), f"person_family_id {pf_id} not in family_ids" + assert pf_id in family_ids, f"person_family_id {pf_id} not in family_ids" def test_family_ids_unique_across_cds(self, stacked_sim_with_overlap): """Same HH in different CDs should get different family_ids.""" @@ -266,9 +252,9 @@ def test_family_ids_unique_across_cds(self, stacked_sim_with_overlap): family_ids = sim.calculate("family_id", map_to="family").values expected_families = n_overlap * n_cds - assert ( - len(family_ids) == expected_families - ), f"Expected {expected_families} families, got {len(family_ids)}" + assert len(family_ids) == expected_families, ( + f"Expected {expected_families} families, got {len(family_ids)}" + ) assert len(set(family_ids)) == expected_families, ( f"Family IDs not unique: " f"{len(set(family_ids))} unique " diff --git a/policyengine_us_data/tests/test_calibration/test_target_config.py b/policyengine_us_data/tests/test_calibration/test_target_config.py index b19fc94f..377d3a64 100644 --- a/policyengine_us_data/tests/test_calibration/test_target_config.py +++ b/policyengine_us_data/tests/test_calibration/test_target_config.py @@ -104,9 +104,7 @@ def test_domain_variable_matching(self, sample_targets): def test_matrix_and_names_stay_in_sync(self, sample_targets): df, X, names = sample_targets - config = { - "exclude": [{"variable": "person_count", "geo_level": "national"}] - } + config = {"exclude": [{"variable": "person_count", "geo_level": "national"}]} out_df, out_X, out_names = apply_target_config(df, X, names, config) assert out_X.shape[0] == len(out_df) assert len(out_names) == len(out_df) @@ -114,9 +112,7 @@ def test_matrix_and_names_stay_in_sync(self, sample_targets): def test_no_match_keeps_all(self, sample_targets): df, X, names = sample_targets - config = { - "exclude": [{"variable": "nonexistent", "geo_level": "national"}] - } + config = {"exclude": [{"variable": "nonexistent", "geo_level": "national"}]} out_df, out_X, out_names = apply_target_config(df, X, names, config) assert len(out_df) == len(df) assert out_X.shape[0] == X.shape[0] diff --git a/policyengine_us_data/tests/test_calibration/test_unified_calibration.py b/policyengine_us_data/tests/test_calibration/test_unified_calibration.py index d182db5a..28a3c906 100644 --- a/policyengine_us_data/tests/test_calibration/test_unified_calibration.py +++ b/policyengine_us_data/tests/test_calibration/test_unified_calibration.py @@ -78,12 +78,8 @@ class TestBlockSaltedDraws: def test_same_block_same_results(self): blocks = np.array(["370010001001001"] * 500) - d1 = compute_block_takeup_for_entities( - "takes_up_snap_if_eligible", 0.8, blocks - ) - d2 = compute_block_takeup_for_entities( - "takes_up_snap_if_eligible", 0.8, blocks - ) + d1 = compute_block_takeup_for_entities("takes_up_snap_if_eligible", 0.8, blocks) + d2 = compute_block_takeup_for_entities("takes_up_snap_if_eligible", 0.8, blocks) np.testing.assert_array_equal(d1, d2) def test_different_blocks_different_results(self): @@ -102,12 +98,8 @@ def test_different_blocks_different_results(self): def test_different_vars_different_results(self): blocks = np.array(["370010001001001"] * 500) - d1 = compute_block_takeup_for_entities( - "takes_up_snap_if_eligible", 0.8, blocks - ) - d2 = compute_block_takeup_for_entities( - "takes_up_aca_if_eligible", 0.8, blocks - ) + d1 = compute_block_takeup_for_entities("takes_up_snap_if_eligible", 0.8, blocks) + d2 = compute_block_takeup_for_entities("takes_up_aca_if_eligible", 0.8, blocks) assert not np.array_equal(d1, d2) def test_hh_salt_differs_from_block_only(self): @@ -315,9 +307,7 @@ class TestGeographyAssignmentCountyFips: """Verify county_fips field on GeographyAssignment.""" def test_county_fips_equals_block_prefix(self): - blocks = np.array( - ["370010001001001", "480010002002002", "060370003003003"] - ) + blocks = np.array(["370010001001001", "480010002002002", "060370003003003"]) ga = GeographyAssignment( block_geoid=blocks, cd_geoid=np.array(["3701", "4801", "0613"]), @@ -350,12 +340,8 @@ class TestBlockTakeupSeeding: def test_reproducible(self): blocks = np.array(["010010001001001"] * 50 + ["020010001001001"] * 50) - r1 = compute_block_takeup_for_entities( - "takes_up_snap_if_eligible", 0.8, blocks - ) - r2 = compute_block_takeup_for_entities( - "takes_up_snap_if_eligible", 0.8, blocks - ) + r1 = compute_block_takeup_for_entities("takes_up_snap_if_eligible", 0.8, blocks) + r2 = compute_block_takeup_for_entities("takes_up_snap_if_eligible", 0.8, blocks) np.testing.assert_array_equal(r1, r2) def test_different_blocks_different_draws(self): @@ -556,17 +542,13 @@ def test_state_specific_rate_resolved_from_block(self): n = 5000 blocks_nc = np.array(["370010001001001"] * n) - result_nc = compute_block_takeup_for_entities( - var, rate_dict, blocks_nc - ) + result_nc = compute_block_takeup_for_entities(var, rate_dict, blocks_nc) # NC rate=0.9, expect ~90% frac_nc = result_nc.mean() assert 0.85 < frac_nc < 0.95, f"NC frac={frac_nc}" blocks_tx = np.array(["480010002002002"] * n) - result_tx = compute_block_takeup_for_entities( - var, rate_dict, blocks_tx - ) + result_tx = compute_block_takeup_for_entities(var, rate_dict, blocks_tx) # TX rate=0.6, expect ~60% frac_tx = result_tx.mean() assert 0.55 < frac_tx < 0.65, f"TX frac={frac_tx}" diff --git a/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py b/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py index c8588b78..dbc76fb1 100644 --- a/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py +++ b/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py @@ -130,9 +130,7 @@ def _insert_aca_ptc_data(engine): ] for tid, sid, var, val, period in targets: conn.execute( - text( - "INSERT INTO targets VALUES (:tid, :sid, :var, :val, :period, 1)" - ), + text("INSERT INTO targets VALUES (:tid, :sid, :var, :val, :period, 1)"), { "tid": tid, "sid": sid, @@ -191,9 +189,7 @@ def test_geographic_id_populated(self): df = b._query_targets({"domain_variables": ["aca_ptc"]}) national = df[df["geo_level"] == "national"] self.assertTrue((national["geographic_id"] == "US").all()) - state_ca = df[ - (df["geo_level"] == "state") & (df["geographic_id"] == "6") - ] + state_ca = df[(df["geo_level"] == "state") & (df["geographic_id"] == "6")] self.assertGreater(len(state_ca), 0) @@ -225,9 +221,9 @@ def _get_targets_with_uprating(self, cpi_factor=1.1, pop_factor=1.02): } df["original_value"] = df["value"].copy() df["uprating_factor"] = df.apply( - lambda row: b._get_uprating_info( - row["variable"], row["period"], factors - )[0], + lambda row: b._get_uprating_info(row["variable"], row["period"], factors)[ + 0 + ], axis=1, ) df["value"] = df["original_value"] * df["uprating_factor"] @@ -252,9 +248,7 @@ def test_cd_sums_match_uprated_state(self): & (result["geo_level"] == "district") & ( result["geographic_id"].apply( - lambda g, s=sf: ( - int(g) // 100 == s if g.isdigit() else False - ) + lambda g, s=sf: int(g) // 100 == s if g.isdigit() else False ) ) ] @@ -288,8 +282,7 @@ def test_hif_is_one_when_cds_sum_to_state(self): b, df, factors = self._get_targets_with_uprating(cpi_factor=1.15) result = b._apply_hierarchical_uprating(df, ["aca_ptc"], factors) cd_aca = result[ - (result["variable"] == "aca_ptc") - & (result["geo_level"] == "district") + (result["variable"] == "aca_ptc") & (result["geo_level"] == "district") ] for _, row in cd_aca.iterrows(): self.assertAlmostEqual(row["hif"], 1.0, places=6) @@ -561,18 +554,14 @@ def test_state_fips_set_correctly(self, mock_msim_cls, mock_gcv): ) # First sim should get state 37 - fips_calls_0 = [ - c for c in sims[0].set_input_calls if c[0] == "state_fips" - ] + fips_calls_0 = [c for c in sims[0].set_input_calls if c[0] == "state_fips"] assert len(fips_calls_0) == 1 np.testing.assert_array_equal( fips_calls_0[0][2], np.full(4, 37, dtype=np.int32) ) # Second sim should get state 48 - fips_calls_1 = [ - c for c in sims[1].set_input_calls if c[0] == "state_fips" - ] + fips_calls_1 = [c for c in sims[1].set_input_calls if c[0] == "state_fips"] assert len(fips_calls_1) == 1 np.testing.assert_array_equal( fips_calls_1[0][2], np.full(4, 48, dtype=np.int32) @@ -613,9 +602,9 @@ def test_takeup_vars_forced_true(self, mock_msim_cls, mock_gcv): assert values.all(), f"{var} not forced True" set_true_vars.add(var) - assert ( - takeup_var_names == set_true_vars - ), f"Missing forced-true vars: {takeup_var_names - set_true_vars}" + assert takeup_var_names == set_true_vars, ( + f"Missing forced-true vars: {takeup_var_names - set_true_vars}" + ) # Entity-level calculation happens for affected target entity_calcs = [ @@ -738,9 +727,7 @@ def test_return_structure(self, mock_msim_cls, mock_gcv, mock_county_idx): return_value=["var_a"], ) @patch("policyengine_us.Microsimulation") - def test_sim_reuse_within_state( - self, mock_msim_cls, mock_gcv, mock_county_idx - ): + def test_sim_reuse_within_state(self, mock_msim_cls, mock_gcv, mock_county_idx): sim = _FakeSimulation() mock_msim_cls.return_value = sim @@ -771,9 +758,7 @@ def test_sim_reuse_within_state( return_value=[], ) @patch("policyengine_us.Microsimulation") - def test_fresh_sim_across_states( - self, mock_msim_cls, mock_gcv, mock_county_idx - ): + def test_fresh_sim_across_states(self, mock_msim_cls, mock_gcv, mock_county_idx): mock_msim_cls.side_effect = [ _FakeSimulation(), _FakeSimulation(), @@ -802,9 +787,7 @@ def test_fresh_sim_across_states( return_value=["var_a", "county"], ) @patch("policyengine_us.Microsimulation") - def test_delete_arrays_per_county( - self, mock_msim_cls, mock_gcv, mock_county_idx - ): + def test_delete_arrays_per_county(self, mock_msim_cls, mock_gcv, mock_county_idx): sim = _FakeSimulation() mock_msim_cls.return_value = sim @@ -879,9 +862,7 @@ def _make_geo(self, states, n_records=4): return_value=[], ) @patch("policyengine_us.Microsimulation") - def test_workers_gt1_creates_pool( - self, mock_msim_cls, mock_gcv, mock_pool_cls - ): + def test_workers_gt1_creates_pool(self, mock_msim_cls, mock_gcv, mock_pool_cls): mock_future = MagicMock() mock_future.result.return_value = ( 37, @@ -1012,9 +993,7 @@ def test_workers_gt1_creates_pool( return_value=[], ) @patch("policyengine_us.Microsimulation") - def test_workers_1_skips_pool( - self, mock_msim_cls, mock_gcv, mock_county_idx - ): + def test_workers_1_skips_pool(self, mock_msim_cls, mock_gcv, mock_county_idx): mock_msim_cls.return_value = _FakeSimulation() builder = self._make_builder() geo = self._make_geo(["37001"]) diff --git a/policyengine_us_data/tests/test_calibration/test_xw_consistency.py b/policyengine_us_data/tests/test_calibration/test_xw_consistency.py index 78ea4723..403fe1af 100644 --- a/policyengine_us_data/tests/test_calibration/test_xw_consistency.py +++ b/policyengine_us_data/tests/test_calibration/test_xw_consistency.py @@ -17,9 +17,7 @@ from policyengine_us_data.storage import STORAGE_FOLDER -DATASET_PATH = str( - STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5" -) +DATASET_PATH = str(STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5") DB_PATH = str(STORAGE_FOLDER / "calibration" / "policy_data.db") DB_URI = f"sqlite:///{DB_PATH}" @@ -101,9 +99,7 @@ def test_xw_matches_stacked_sim(): for i, cd in enumerate(cds_ordered): mask = geography.cd_geoid.astype(str) == cd cd_weights[cd] = w[mask].sum() - top_cds = sorted(cd_weights, key=cd_weights.get, reverse=True)[ - :N_CDS_TO_CHECK - ] + top_cds = sorted(cd_weights, key=cd_weights.get, reverse=True)[:N_CDS_TO_CHECK] check_vars = ["aca_ptc", "snap"] tmpdir = tempfile.mkdtemp() @@ -129,8 +125,7 @@ def test_xw_matches_stacked_sim(): stacked_sum = (vals * hh_weight).sum() cd_row = targets_df[ - (targets_df["variable"] == var) - & (targets_df["geographic_id"] == cd) + (targets_df["variable"] == var) & (targets_df["geographic_id"] == cd) ] if len(cd_row) == 0: continue diff --git a/policyengine_us_data/tests/test_constraint_validation.py b/policyengine_us_data/tests/test_constraint_validation.py index 29920475..e494f5c9 100644 --- a/policyengine_us_data/tests/test_constraint_validation.py +++ b/policyengine_us_data/tests/test_constraint_validation.py @@ -138,9 +138,7 @@ def test_conflicting_lower_bounds(self): Constraint(variable="age", operation=">", value="20"), Constraint(variable="age", operation=">=", value="25"), ] - with pytest.raises( - ConstraintValidationError, match="conflicting lower bounds" - ): + with pytest.raises(ConstraintValidationError, match="conflicting lower bounds"): ensure_consistent_constraint_set(constraints) def test_conflicting_upper_bounds(self): @@ -149,9 +147,7 @@ def test_conflicting_upper_bounds(self): Constraint(variable="age", operation="<", value="50"), Constraint(variable="age", operation="<=", value="45"), ] - with pytest.raises( - ConstraintValidationError, match="conflicting upper bounds" - ): + with pytest.raises(ConstraintValidationError, match="conflicting upper bounds"): ensure_consistent_constraint_set(constraints) @@ -193,9 +189,7 @@ class TestNonNumericValues: def test_string_equality_valid(self): """medicaid_enrolled == 'True' should pass.""" constraints = [ - Constraint( - variable="medicaid_enrolled", operation="==", value="True" - ), + Constraint(variable="medicaid_enrolled", operation="==", value="True"), ] ensure_consistent_constraint_set(constraints) # No exception diff --git a/policyengine_us_data/tests/test_database_build.py b/policyengine_us_data/tests/test_database_build.py index 0bdcdeb7..87a6ce08 100644 --- a/policyengine_us_data/tests/test_database_build.py +++ b/policyengine_us_data/tests/test_database_build.py @@ -22,7 +22,9 @@ # HuggingFace URL for the stratified CPS dataset. # ETL scripts use this only to derive the time period (2024). -HF_DATASET = "hf://policyengine/policyengine-us-data/calibration/stratified_extended_cps.h5" +HF_DATASET = ( + "hf://policyengine/policyengine-us-data/calibration/stratified_extended_cps.h5" +) # Scripts run in the same order as `make database` in the Makefile. # create_database_tables.py does not use etl_argparser. @@ -77,9 +79,7 @@ def built_db(): ) if errors: - pytest.fail( - f"{len(errors)} ETL script(s) failed:\n" + "\n\n".join(errors) - ) + pytest.fail(f"{len(errors)} ETL script(s) failed:\n" + "\n\n".join(errors)) assert DB_PATH.exists(), "policy_data.db was not created" return DB_PATH @@ -96,9 +96,7 @@ def test_expected_tables_exist(built_db): conn = sqlite3.connect(str(built_db)) tables = { row[0] - for row in conn.execute( - "SELECT name FROM sqlite_master WHERE type='table'" - ) + for row in conn.execute("SELECT name FROM sqlite_master WHERE type='table'") } conn.close() @@ -122,9 +120,9 @@ def test_national_targets_loaded(built_db): variables = {r[0] for r in rows} for expected in ["snap", "social_security", "ssi"]: - assert ( - expected in variables - ), f"National target '{expected}' missing. Found: {sorted(variables)}" + assert expected in variables, ( + f"National target '{expected}' missing. Found: {sorted(variables)}" + ) def test_state_income_tax_targets(built_db): @@ -148,9 +146,9 @@ def test_state_income_tax_targets(built_db): # California should be the largest, over $100B. ca_val = state_totals.get("06") or state_totals.get("6") assert ca_val is not None, "California (FIPS 06) target missing" - assert ( - ca_val > 100e9 - ), f"California income tax should be > $100B, got ${ca_val / 1e9:.1f}B" + assert ca_val > 100e9, ( + f"California income tax should be > $100B, got ${ca_val / 1e9:.1f}B" + ) def test_congressional_district_strata(built_db): @@ -171,9 +169,7 @@ def test_all_target_variables_exist_in_policyengine(built_db): from policyengine_us.system import system conn = sqlite3.connect(str(built_db)) - variables = { - r[0] for r in conn.execute("SELECT DISTINCT variable FROM targets") - } + variables = {r[0] for r in conn.execute("SELECT DISTINCT variable FROM targets")} conn.close() missing = [v for v in variables if v not in system.variables] diff --git a/policyengine_us_data/tests/test_datasets/test_county_fips.py b/policyengine_us_data/tests/test_datasets/test_county_fips.py index 0414aa55..ac2eb9fa 100644 --- a/policyengine_us_data/tests/test_datasets/test_county_fips.py +++ b/policyengine_us_data/tests/test_datasets/test_county_fips.py @@ -48,9 +48,7 @@ def mock_upload_to_hf(): def mock_local_folder(): """Mock the LOCAL_FOLDER""" mock_path = MagicMock() - with patch( - "policyengine_us_data.geography.county_fips.LOCAL_FOLDER", mock_path - ): + with patch("policyengine_us_data.geography.county_fips.LOCAL_FOLDER", mock_path): yield mock_path @@ -179,6 +177,4 @@ def test_huggingface_upload(mock_upload_to_hf, mock_to_csv, mock_requests_get): assert call_kwargs["repo_file_path"] == "county_fips_2020.csv.gz" # Verify that the first parameter is a BytesIO instance - assert isinstance( - mock_upload_to_hf.call_args[1]["local_file_path"], BytesIO - ) + assert isinstance(mock_upload_to_hf.call_args[1]["local_file_path"], BytesIO) diff --git a/policyengine_us_data/tests/test_datasets/test_cps.py b/policyengine_us_data/tests/test_datasets/test_cps.py index bbfba73b..f0346939 100644 --- a/policyengine_us_data/tests/test_datasets/test_cps.py +++ b/policyengine_us_data/tests/test_datasets/test_cps.py @@ -13,18 +13,11 @@ def test_cps_has_auto_loan_interest(): RELATIVE_TOLERANCE = 0.4 assert ( - abs( - sim.calculate("auto_loan_interest").sum() - / AUTO_LOAN_INTEREST_TARGET - - 1 - ) + abs(sim.calculate("auto_loan_interest").sum() / AUTO_LOAN_INTEREST_TARGET - 1) < RELATIVE_TOLERANCE ) assert ( - abs( - sim.calculate("auto_loan_balance").sum() / AUTO_LOAN_BALANCE_TARGET - - 1 - ) + abs(sim.calculate("auto_loan_balance").sum() / AUTO_LOAN_BALANCE_TARGET - 1) < RELATIVE_TOLERANCE ) @@ -38,11 +31,7 @@ def test_cps_has_fsla_overtime_premium(): OVERTIME_PREMIUM_TARGET = 70e9 RELATIVE_TOLERANCE = 0.2 assert ( - abs( - sim.calculate("fsla_overtime_premium").sum() - / OVERTIME_PREMIUM_TARGET - - 1 - ) + abs(sim.calculate("fsla_overtime_premium").sum() / OVERTIME_PREMIUM_TARGET - 1) < RELATIVE_TOLERANCE ) diff --git a/policyengine_us_data/tests/test_datasets/test_dataset_sanity.py b/policyengine_us_data/tests/test_datasets/test_dataset_sanity.py index 8314fe7f..4e8732b0 100644 --- a/policyengine_us_data/tests/test_datasets/test_dataset_sanity.py +++ b/policyengine_us_data/tests/test_datasets/test_dataset_sanity.py @@ -41,27 +41,23 @@ def test_ecps_employment_income_positive(ecps_sim): def test_ecps_self_employment_income_positive(ecps_sim): total = ecps_sim.calculate("self_employment_income").sum() - assert ( - total > 50e9 - ), f"self_employment_income sum is {total:.2e}, expected > 50B." + assert total > 50e9, f"self_employment_income sum is {total:.2e}, expected > 50B." def test_ecps_household_count(ecps_sim): """Household count should be roughly 130-160M.""" total_hh = ecps_sim.calculate("household_weight").values.sum() - assert ( - 100e6 < total_hh < 200e6 - ), f"Total households = {total_hh:.2e}, expected 100M-200M." + assert 100e6 < total_hh < 200e6, ( + f"Total households = {total_hh:.2e}, expected 100M-200M." + ) def test_ecps_person_count(ecps_sim): """Weighted person count should be roughly 330M.""" - total_people = ecps_sim.calculate( - "household_weight", map_to="person" - ).values.sum() - assert ( - 250e6 < total_people < 400e6 - ), f"Total people = {total_people:.2e}, expected 250M-400M." + total_people = ecps_sim.calculate("household_weight", map_to="person").values.sum() + assert 250e6 < total_people < 400e6, ( + f"Total people = {total_people:.2e}, expected 250M-400M." + ) def test_ecps_poverty_rate_reasonable(ecps_sim): @@ -84,9 +80,9 @@ def test_ecps_mean_employment_income_reasonable(ecps_sim): """Mean employment income per person should be $20k-$60k.""" income = ecps_sim.calculate("employment_income", map_to="person") mean = income.mean() - assert ( - 15_000 < mean < 80_000 - ), f"Mean employment income = ${mean:,.0f}, expected $15k-$80k." + assert 15_000 < mean < 80_000, ( + f"Mean employment income = ${mean:,.0f}, expected $15k-$80k." + ) # ── CPS sanity checks ─────────────────────────────────────────── @@ -94,9 +90,7 @@ def test_ecps_mean_employment_income_reasonable(ecps_sim): def test_cps_employment_income_positive(cps_sim): total = cps_sim.calculate("employment_income").sum() - assert ( - total > 5e12 - ), f"CPS employment_income sum is {total:.2e}, expected > 5T." + assert total > 5e12, f"CPS employment_income sum is {total:.2e}, expected > 5T." def test_cps_household_count(cps_sim): @@ -122,24 +116,20 @@ def sparse_sim(): def test_sparse_employment_income_positive(sparse_sim): """Sparse dataset employment income must be in the trillions.""" total = sparse_sim.calculate("employment_income").sum() - assert ( - total > 5e12 - ), f"Sparse employment_income sum is {total:.2e}, expected > 5T." + assert total > 5e12, f"Sparse employment_income sum is {total:.2e}, expected > 5T." def test_sparse_household_count(sparse_sim): total_hh = sparse_sim.calculate("household_weight").values.sum() - assert ( - 100e6 < total_hh < 200e6 - ), f"Sparse total households = {total_hh:.2e}, expected 100M-200M." + assert 100e6 < total_hh < 200e6, ( + f"Sparse total households = {total_hh:.2e}, expected 100M-200M." + ) def test_sparse_poverty_rate_reasonable(sparse_sim): in_poverty = sparse_sim.calculate("person_in_poverty", map_to="person") rate = in_poverty.mean() - assert ( - 0.05 < rate < 0.30 - ), f"Sparse poverty rate = {rate:.1%}, expected 5-30%." + assert 0.05 < rate < 0.30, f"Sparse poverty rate = {rate:.1%}, expected 5-30%." # ── File size checks ─────────────────────────────────────────── @@ -153,6 +143,6 @@ def test_ecps_file_size(): if not path.exists(): pytest.skip("enhanced_cps_2024.h5 not found") size_mb = path.stat().st_size / (1024 * 1024) - assert ( - size_mb > 100 - ), f"enhanced_cps_2024.h5 is only {size_mb:.1f}MB, expected >100MB" + assert size_mb > 100, ( + f"enhanced_cps_2024.h5 is only {size_mb:.1f}MB, expected >100MB" + ) diff --git a/policyengine_us_data/tests/test_datasets/test_enhanced_cps.py b/policyengine_us_data/tests/test_datasets/test_enhanced_cps.py index 4c79874e..298de5a4 100644 --- a/policyengine_us_data/tests/test_datasets/test_enhanced_cps.py +++ b/policyengine_us_data/tests/test_datasets/test_enhanced_cps.py @@ -50,10 +50,10 @@ def test_ecps_replicates_jct_tax_expenditures(): & (calibration_log["epoch"] == calibration_log["epoch"].max()) ] - assert ( - jct_rows.rel_abs_error.max() < 0.5 - ), "JCT tax expenditure targets not met (see the calibration log for details). Max relative error: {:.2%}".format( - jct_rows.rel_abs_error.max() + assert jct_rows.rel_abs_error.max() < 0.5, ( + "JCT tax expenditure targets not met (see the calibration log for details). Max relative error: {:.2%}".format( + jct_rows.rel_abs_error.max() + ) ) @@ -71,9 +71,7 @@ def deprecated_test_ecps_replicates_jct_tax_expenditures_full(): } baseline = Microsimulation(dataset=EnhancedCPS_2024) - income_tax_b = baseline.calculate( - "income_tax", period=2024, map_to="household" - ) + income_tax_b = baseline.calculate("income_tax", period=2024, map_to="household") for deduction, target in EXPENDITURE_TARGETS.items(): # Create reform that neutralizes the deduction @@ -82,12 +80,8 @@ def apply(self): self.neutralize_variable(deduction) # Run reform simulation - reformed = Microsimulation( - reform=RepealDeduction, dataset=EnhancedCPS_2024 - ) - income_tax_r = reformed.calculate( - "income_tax", period=2024, map_to="household" - ) + reformed = Microsimulation(reform=RepealDeduction, dataset=EnhancedCPS_2024) + income_tax_r = reformed.calculate("income_tax", period=2024, map_to="household") # Calculate tax expenditure tax_expenditure = (income_tax_r - income_tax_b).sum() @@ -137,9 +131,9 @@ def test_undocumented_matches_ssn_none(): # 1. Per-person equivalence mismatches = np.where(ssn_type_none_mask != undocumented_mask)[0] - assert ( - mismatches.size == 0 - ), f"{mismatches.size} mismatches between 'NONE' SSN and 'UNDOCUMENTED' status" + assert mismatches.size == 0, ( + f"{mismatches.size} mismatches between 'NONE' SSN and 'UNDOCUMENTED' status" + ) # 2. Optional aggregate sanity-check count = undocumented_mask.sum() @@ -164,9 +158,7 @@ def test_aca_calibration(): # Monthly to yearly targets["spending"] = targets["spending"] * 12 # Adjust to match national target - targets["spending"] = targets["spending"] * ( - 98e9 / targets["spending"].sum() - ) + targets["spending"] = targets["spending"] * (98e9 / targets["spending"].sum()) sim = Microsimulation(dataset=EnhancedCPS_2024) state_code_hh = sim.calculate("state_code", map_to="household").values @@ -189,9 +181,7 @@ def test_aca_calibration(): if pct_error > TOLERANCE: failed = True - assert ( - not failed - ), f"One or more states exceeded tolerance of {TOLERANCE:.0%}." + assert not failed, f"One or more states exceeded tolerance of {TOLERANCE:.0%}." def test_immigration_status_diversity(): @@ -227,20 +217,18 @@ def test_immigration_status_diversity(): ) # Also check that we have a reasonable percentage of citizens (should be 85-90%) - assert ( - 80 < citizen_pct < 95 - ), f"Citizen percentage ({citizen_pct:.1f}%) outside expected range (80-95%)" + assert 80 < citizen_pct < 95, ( + f"Citizen percentage ({citizen_pct:.1f}%) outside expected range (80-95%)" + ) # Check that we have some non-citizens non_citizen_pct = 100 - citizen_pct - assert ( - non_citizen_pct > 5 - ), f"Too few non-citizens ({non_citizen_pct:.1f}%) - expected at least 5%" - - print( - f"Immigration status diversity test passed: {citizen_pct:.1f}% citizens" + assert non_citizen_pct > 5, ( + f"Too few non-citizens ({non_citizen_pct:.1f}%) - expected at least 5%" ) + print(f"Immigration status diversity test passed: {citizen_pct:.1f}% citizens") + def test_medicaid_calibration(): @@ -277,6 +265,4 @@ def test_medicaid_calibration(): if pct_error > TOLERANCE: failed = True - assert ( - not failed - ), f"One or more states exceeded tolerance of {TOLERANCE:.0%}." + assert not failed, f"One or more states exceeded tolerance of {TOLERANCE:.0%}." diff --git a/policyengine_us_data/tests/test_datasets/test_sipp_assets.py b/policyengine_us_data/tests/test_datasets/test_sipp_assets.py index 0f839a9c..a79b4bce 100644 --- a/policyengine_us_data/tests/test_datasets/test_sipp_assets.py +++ b/policyengine_us_data/tests/test_datasets/test_sipp_assets.py @@ -101,12 +101,12 @@ def test_liquid_assets_distribution(): MEDIAN_MIN = 3_000 MEDIAN_MAX = 20_000 - assert ( - weighted_median > MEDIAN_MIN - ), f"Median liquid assets ${weighted_median:,.0f} below minimum ${MEDIAN_MIN:,}" - assert ( - weighted_median < MEDIAN_MAX - ), f"Median liquid assets ${weighted_median:,.0f} above maximum ${MEDIAN_MAX:,}" + assert weighted_median > MEDIAN_MIN, ( + f"Median liquid assets ${weighted_median:,.0f} below minimum ${MEDIAN_MIN:,}" + ) + assert weighted_median < MEDIAN_MAX, ( + f"Median liquid assets ${weighted_median:,.0f} above maximum ${MEDIAN_MAX:,}" + ) def test_asset_categories_exist(): @@ -127,9 +127,7 @@ def test_asset_categories_exist(): assert bonds >= 0, "Bond assets should be non-negative" # Bank accounts typically largest category of liquid assets - assert ( - bank > stocks * 0.3 - ), "Bank accounts should be substantial relative to stocks" + assert bank > stocks * 0.3, "Bank accounts should be substantial relative to stocks" def test_low_asset_households(): @@ -155,9 +153,9 @@ def test_low_asset_households(): MIN_PCT = 0.10 MAX_PCT = 0.70 - assert ( - below_2k > MIN_PCT - ), f"Only {below_2k:.1%} have <$2k liquid assets, expected at least {MIN_PCT:.0%}" - assert ( - below_2k < MAX_PCT - ), f"{below_2k:.1%} have <$2k liquid assets, expected at most {MAX_PCT:.0%}" + assert below_2k > MIN_PCT, ( + f"Only {below_2k:.1%} have <$2k liquid assets, expected at least {MIN_PCT:.0%}" + ) + assert below_2k < MAX_PCT, ( + f"{below_2k:.1%} have <$2k liquid assets, expected at most {MAX_PCT:.0%}" + ) diff --git a/policyengine_us_data/tests/test_datasets/test_small_enhanced_cps.py b/policyengine_us_data/tests/test_datasets/test_small_enhanced_cps.py index 23b7b2dc..9316d390 100644 --- a/policyengine_us_data/tests/test_datasets/test_small_enhanced_cps.py +++ b/policyengine_us_data/tests/test_datasets/test_small_enhanced_cps.py @@ -19,12 +19,10 @@ def test_small_ecps_loads(year: int): # Employment income should be positive (not zero from missing vars) emp_income = sim.calculate("employment_income", 2025).sum() - assert ( - emp_income > 0 - ), f"Small ECPS employment_income sum is {emp_income}, expected > 0." + assert emp_income > 0, ( + f"Small ECPS employment_income sum is {emp_income}, expected > 0." + ) # Should have a reasonable number of households hh_count = len(sim.calculate("household_net_income", 2025)) - assert ( - hh_count > 100 - ), f"Small ECPS has only {hh_count} households, expected > 100." + assert hh_count > 100, f"Small ECPS has only {hh_count} households, expected > 100." diff --git a/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py b/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py index bea1e3b3..a7ee941b 100644 --- a/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py +++ b/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py @@ -115,10 +115,10 @@ def test_sparse_ecps_replicates_jct_tax_expenditures(): & (calibration_log["epoch"] == calibration_log["epoch"].max()) ] - assert ( - jct_rows.rel_abs_error.max() < 0.5 - ), "JCT tax expenditure targets not met (see the calibration log for details). Max relative error: {:.2%}".format( - jct_rows.rel_abs_error.max() + assert jct_rows.rel_abs_error.max() < 0.5, ( + "JCT tax expenditure targets not met (see the calibration log for details). Max relative error: {:.2%}".format( + jct_rows.rel_abs_error.max() + ) ) @@ -133,9 +133,7 @@ def deprecated_test_sparse_ecps_replicates_jct_tax_expenditures_full(sim): } baseline = sim - income_tax_b = baseline.calculate( - "income_tax", period=2024, map_to="household" - ) + income_tax_b = baseline.calculate("income_tax", period=2024, map_to="household") for deduction, target in EXPENDITURE_TARGETS.items(): # Create reform that neutralizes the deduction @@ -145,9 +143,7 @@ def apply(self): # Run reform simulation reformed = Microsimulation(reform=RepealDeduction, dataset=sim.dataset) - income_tax_r = reformed.calculate( - "income_tax", period=2024, map_to="household" - ) + income_tax_r = reformed.calculate("income_tax", period=2024, map_to="household") # Calculate tax expenditure tax_expenditure = (income_tax_r - income_tax_b).sum() @@ -188,9 +184,7 @@ def test_sparse_aca_calibration(sim): # Monthly to yearly targets["spending"] = targets["spending"] * 12 # Adjust to match national target - targets["spending"] = targets["spending"] * ( - 98e9 / targets["spending"].sum() - ) + targets["spending"] = targets["spending"] * (98e9 / targets["spending"].sum()) state_code_hh = sim.calculate("state_code", map_to="household").values aca_ptc = sim.calculate("aca_ptc", map_to="household", period=2025) @@ -212,9 +206,7 @@ def test_sparse_aca_calibration(sim): if pct_error > TOLERANCE: failed = True - assert ( - not failed - ), f"One or more states exceeded tolerance of {TOLERANCE:.0%}." + assert not failed, f"One or more states exceeded tolerance of {TOLERANCE:.0%}." def test_sparse_medicaid_calibration(sim): @@ -246,6 +238,4 @@ def test_sparse_medicaid_calibration(sim): if pct_error > TOLERANCE: failed = True - assert ( - not failed - ), f"One or more states exceeded tolerance of {TOLERANCE:.0%}." + assert not failed, f"One or more states exceeded tolerance of {TOLERANCE:.0%}." diff --git a/policyengine_us_data/tests/test_format_comparison.py b/policyengine_us_data/tests/test_format_comparison.py index d5091b8b..722064b5 100644 --- a/policyengine_us_data/tests/test_format_comparison.py +++ b/policyengine_us_data/tests/test_format_comparison.py @@ -87,10 +87,7 @@ def _read_h5py_arrays(h5py_path: str): arr = f[var][period_key][:] if arr.dtype.kind in ("S", "O"): arr = np.array( - [ - x.decode() if isinstance(x, bytes) else str(x) - for x in arr - ] + [x.decode() if isinstance(x, bytes) else str(x) for x in arr] ) # Wrap in nested dict keyed by the period string data[var] = {period_key: arr} @@ -117,9 +114,7 @@ def h5py_to_hdfstore(h5py_path: str, hdfstore_path: str) -> dict: print("Reading h5py file...") data, time_period, h5_vars = _read_h5py_arrays(h5py_path) n_persons = len(next(iter(data.get("person_id", {}).values()), [])) - print( - f" {len(h5_vars)} variables, {n_persons:,} persons, year={time_period}" - ) + print(f" {len(h5_vars)} variables, {n_persons:,} persons, year={time_period}") print("Splitting into entity DataFrames...") entity_dfs = split_data_into_entity_dfs(data, system, time_period) @@ -187,11 +182,7 @@ def compare_formats(h5py_path: str, hdfstore_path: str) -> dict: hdf_unique = np.unique(hdf_values) if h5_values.dtype.kind in ("U", "S", "O"): match = set( - ( - x.decode() - if isinstance(x, bytes) - else str(x) - ) + (x.decode() if isinstance(x, bytes) else str(x)) for x in h5_unique ) == set(str(x) for x in hdf_unique) else: @@ -217,17 +208,11 @@ def compare_formats(h5py_path: str, hdfstore_path: str) -> dict: if h5_values.dtype.kind in ("U", "S", "O"): h5_str = np.array( [ - ( - x.decode() - if isinstance(x, bytes) - else str(x) - ) + (x.decode() if isinstance(x, bytes) else str(x)) for x in h5_values ] ) - hdf_str = np.array( - [str(x) for x in hdf_values] - ) + hdf_str = np.array([str(x) for x in hdf_values]) if np.array_equal(h5_str, hdf_str): passed.append(var) else: @@ -328,12 +313,12 @@ def test_roundtrip(h5py_path, tmp_path): result = compare_formats(h5py_path, hdfstore_path) print_results(result) - assert ( - len(result["failed"]) == 0 - ), f"{len(result['failed'])} variables have mismatched values" - assert ( - len(result["skipped"]) == 0 - ), f"{len(result['skipped'])} variables missing from HDFStore" + assert len(result["failed"]) == 0, ( + f"{len(result['failed'])} variables have mismatched values" + ) + assert len(result["skipped"]) == 0, ( + f"{len(result['skipped'])} variables missing from HDFStore" + ) def test_manifest(h5py_path, tmp_path): @@ -342,9 +327,7 @@ def test_manifest(h5py_path, tmp_path): h5py_to_hdfstore(h5py_path, hdfstore_path) with pd.HDFStore(hdfstore_path, "r") as store: - assert ( - "/_variable_metadata" in store.keys() - ), "Missing _variable_metadata table" + assert "/_variable_metadata" in store.keys(), "Missing _variable_metadata table" manifest = store["/_variable_metadata"] assert "variable" in manifest.columns assert "entity" in manifest.columns @@ -363,17 +346,15 @@ def test_all_entities(h5py_path, tmp_path): expected = set(ENTITIES) with pd.HDFStore(hdfstore_path, "r") as store: - actual = { - k.lstrip("/") for k in store.keys() if not k.startswith("/_") - } + actual = {k.lstrip("/") for k in store.keys() if not k.startswith("/_")} missing = expected - actual assert not missing, f"Missing entity tables: {missing}" for entity in expected: df = store[f"/{entity}"] assert len(df) > 0, f"Entity {entity} has 0 rows" - assert ( - f"{entity}_id" in df.columns - ), f"Entity {entity} missing {entity}_id column" + assert f"{entity}_id" in df.columns, ( + f"Entity {entity} missing {entity}_id column" + ) print(f" {entity}: {len(df):,} rows, {len(df.columns)} cols") @@ -384,9 +365,7 @@ def test_all_entities(h5py_path, tmp_path): parser = argparse.ArgumentParser( description="Convert h5py dataset to HDFStore and verify roundtrip" ) - parser.add_argument( - "--h5py-path", required=True, help="Path to h5py format file" - ) + parser.add_argument("--h5py-path", required=True, help="Path to h5py format file") parser.add_argument( "--output-path", default=None, diff --git a/policyengine_us_data/tests/test_puf_impute.py b/policyengine_us_data/tests/test_puf_impute.py index fcdcf763..d968fb16 100644 --- a/policyengine_us_data/tests/test_puf_impute.py +++ b/policyengine_us_data/tests/test_puf_impute.py @@ -57,9 +57,7 @@ def _make_data( if age is not None: data["age"] = {tp: np.concatenate([age, age]).astype(np.float32)} if is_male is not None: - data["is_male"] = { - tp: np.concatenate([is_male, is_male]).astype(np.float32) - } + data["is_male"] = {tp: np.concatenate([is_male, is_male]).astype(np.float32)} return data, n, tp diff --git a/policyengine_us_data/tests/test_schema_views_and_lookups.py b/policyengine_us_data/tests/test_schema_views_and_lookups.py index d7495ff3..c8e5f4f8 100644 --- a/policyengine_us_data/tests/test_schema_views_and_lookups.py +++ b/policyengine_us_data/tests/test_schema_views_and_lookups.py @@ -227,9 +227,7 @@ def _query_stratum_domain(self): from sqlalchemy import text with self.engine.connect() as conn: - rows = conn.execute( - text("SELECT * FROM stratum_domain") - ).fetchall() + rows = conn.execute(text("SELECT * FROM stratum_domain")).fetchall() return rows def test_geographic_stratum_excluded(self): @@ -291,18 +289,14 @@ def _query_target_overview(self): from sqlalchemy import text with self.engine.connect() as conn: - rows = conn.execute( - text("SELECT * FROM target_overview") - ).fetchall() + rows = conn.execute(text("SELECT * FROM target_overview")).fetchall() return rows def _overview_columns(self): from sqlalchemy import text with self.engine.connect() as conn: - cursor = conn.execute( - text("SELECT * FROM target_overview LIMIT 0") - ) + cursor = conn.execute(text("SELECT * FROM target_overview LIMIT 0")) return [desc[0] for desc in cursor.cursor.description] def test_national_geo_level(self): diff --git a/policyengine_us_data/utils/census.py b/policyengine_us_data/utils/census.py index c61cc166..422d750c 100644 --- a/policyengine_us_data/utils/census.py +++ b/policyengine_us_data/utils/census.py @@ -139,9 +139,7 @@ def get_census_docs(year): - docs_url = ( - f"https://api.census.gov/data/{year}/acs/acs1/subject/variables.json" - ) + docs_url = f"https://api.census.gov/data/{year}/acs/acs1/subject/variables.json" cache_file = f"census_docs_{year}.json" if is_cached(cache_file): logger.info(f"Using cached {cache_file}") diff --git a/policyengine_us_data/utils/constraint_validation.py b/policyengine_us_data/utils/constraint_validation.py index d3c4305d..f533739c 100644 --- a/policyengine_us_data/utils/constraint_validation.py +++ b/policyengine_us_data/utils/constraint_validation.py @@ -111,9 +111,7 @@ def _check_operation_compatibility(var_name: str, operations: set) -> None: ) -def _check_range_validity( - var_name: str, constraints: List[Constraint] -) -> None: +def _check_range_validity(var_name: str, constraints: List[Constraint]) -> None: """Check that range constraints don't create an empty range.""" lower_bound = float("-inf") upper_bound = float("inf") @@ -128,9 +126,7 @@ def _check_range_validity( continue if c.operation == ">": - if val > lower_bound or ( - val == lower_bound and not lower_inclusive - ): + if val > lower_bound or (val == lower_bound and not lower_inclusive): lower_bound = val lower_inclusive = False elif c.operation == ">=": @@ -138,9 +134,7 @@ def _check_range_validity( lower_bound = val lower_inclusive = True elif c.operation == "<": - if val < upper_bound or ( - val == upper_bound and not upper_inclusive - ): + if val < upper_bound or (val == upper_bound and not upper_inclusive): upper_bound = val upper_inclusive = False elif c.operation == "<=": @@ -154,9 +148,7 @@ def _check_range_validity( f"{var_name}: empty range - lower bound {lower_bound} > " f"upper bound {upper_bound}" ) - if lower_bound == upper_bound and not ( - lower_inclusive and upper_inclusive - ): + if lower_bound == upper_bound and not (lower_inclusive and upper_inclusive): raise ConstraintValidationError( f"{var_name}: empty range - bounds equal at {lower_bound} " "but not both inclusive" diff --git a/policyengine_us_data/utils/data_upload.py b/policyengine_us_data/utils/data_upload.py index e9509837..c8a50036 100644 --- a/policyengine_us_data/utils/data_upload.py +++ b/policyengine_us_data/utils/data_upload.py @@ -117,18 +117,14 @@ def upload_files_to_gcs( Upload files to Google Cloud Storage and set metadata with the version. """ credentials, project_id = google.auth.default() - storage_client = storage.Client( - credentials=credentials, project=project_id - ) + storage_client = storage.Client(credentials=credentials, project=project_id) bucket = storage_client.bucket(gcs_bucket_name) for file_path in files: file_path = Path(file_path) blob = bucket.blob(file_path.name) blob.upload_from_filename(file_path) - logging.info( - f"Uploaded {file_path.name} to GCS bucket {gcs_bucket_name}." - ) + logging.info(f"Uploaded {file_path.name} to GCS bucket {gcs_bucket_name}.") # Set metadata blob.metadata = {"version": version} @@ -165,9 +161,7 @@ def upload_local_area_file( # Upload to GCS with subdirectory credentials, project_id = google.auth.default() - storage_client = storage.Client( - credentials=credentials, project=project_id - ) + storage_client = storage.Client(credentials=credentials, project=project_id) bucket = storage_client.bucket(gcs_bucket_name) blob_name = f"{subdirectory}/{file_path.name}" @@ -337,9 +331,7 @@ def upload_to_staging_hf( f"Uploaded batch {i // batch_size + 1}: {len(operations)} files to staging/" ) - logging.info( - f"Total: uploaded {total_uploaded} files to staging/ in HuggingFace" - ) + logging.info(f"Total: uploaded {total_uploaded} files to staging/ in HuggingFace") return total_uploaded @@ -490,9 +482,7 @@ def upload_from_hf_staging_to_gcs( token = os.environ.get("HUGGING_FACE_TOKEN") credentials, project_id = google.auth.default() - storage_client = storage.Client( - credentials=credentials, project=project_id - ) + storage_client = storage.Client(credentials=credentials, project=project_id) bucket = storage_client.bucket(gcs_bucket_name) uploaded = 0 diff --git a/policyengine_us_data/utils/db.py b/policyengine_us_data/utils/db.py index ad0c0669..128dbb78 100644 --- a/policyengine_us_data/utils/db.py +++ b/policyengine_us_data/utils/db.py @@ -11,9 +11,7 @@ ) from policyengine_us_data.storage import STORAGE_FOLDER -DEFAULT_DATASET = str( - STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5" -) +DEFAULT_DATASET = str(STORAGE_FOLDER / "source_imputed_stratified_extended_cps_2024.h5") def etl_argparser( @@ -46,10 +44,7 @@ def etl_argparser( args = parser.parse_args() - if ( - not args.dataset.startswith("hf://") - and not Path(args.dataset).exists() - ): + if not args.dataset.startswith("hf://") and not Path(args.dataset).exists(): raise FileNotFoundError( f"Dataset not found: {args.dataset}\n" f"Either build it locally (`make data`) or pass a " @@ -71,18 +66,14 @@ def get_stratum_by_id(session: Session, stratum_id: int) -> Optional[Stratum]: return session.get(Stratum, stratum_id) -def get_simple_stratum_by_ucgid( - session: Session, ucgid: str -) -> Optional[Stratum]: +def get_simple_stratum_by_ucgid(session: Session, ucgid: str) -> Optional[Stratum]: """ Finds a stratum defined *only* by a single ucgid_str constraint. """ constraint_count_subquery = ( select( StratumConstraint.stratum_id, - sa.func.count(StratumConstraint.stratum_id).label( - "constraint_count" - ), + sa.func.count(StratumConstraint.stratum_id).label("constraint_count"), ) .group_by(StratumConstraint.stratum_id) .subquery() @@ -139,16 +130,12 @@ def parse_ucgid(ucgid_str: str) -> Dict: elif ucgid_str.startswith("0400000US"): state_fips = int(ucgid_str[9:]) return {"type": "state", "state_fips": state_fips} - elif ucgid_str.startswith("5001800US") or ucgid_str.startswith( - "5001900US" - ): + elif ucgid_str.startswith("5001800US") or ucgid_str.startswith("5001900US"): # 5001800US = 118th Congress, 5001900US = 119th Congress state_and_district = ucgid_str[9:] state_fips = int(state_and_district[:2]) district_number = int(state_and_district[2:]) - if district_number == 0 or ( - state_fips == 11 and district_number == 98 - ): + if district_number == 0 or (state_fips == 11 and district_number == 98): district_number = 1 cd_geoid = state_fips * 100 + district_number return { @@ -203,9 +190,7 @@ def get_geographic_strata(session: Session) -> Dict: if not constraints: strata_map["national"] = stratum.stratum_id else: - constraint_vars = { - c.constraint_variable: c.value for c in constraints - } + constraint_vars = {c.constraint_variable: c.value for c in constraints} if "congressional_district_geoid" in constraint_vars: cd_geoid = int(constraint_vars["congressional_district_geoid"]) diff --git a/policyengine_us_data/utils/huggingface.py b/policyengine_us_data/utils/huggingface.py index 7a090d25..9b1e48cb 100644 --- a/policyengine_us_data/utils/huggingface.py +++ b/policyengine_us_data/utils/huggingface.py @@ -10,9 +10,7 @@ ) -def download( - repo: str, repo_filename: str, local_folder: str, version: str = None -): +def download(repo: str, repo_filename: str, local_folder: str, version: str = None): hf_hub_download( repo_id=repo, @@ -218,15 +216,11 @@ def upload_calibration_artifacts( if log_dir: # Upload run config to calibration/ root for artifact validation - run_config_local = os.path.join( - log_dir, f"{prefix}unified_run_config.json" - ) + run_config_local = os.path.join(log_dir, f"{prefix}unified_run_config.json") if os.path.exists(run_config_local): operations.append( CommitOperationAdd( - path_in_repo=( - f"calibration/{prefix}unified_run_config.json" - ), + path_in_repo=(f"calibration/{prefix}unified_run_config.json"), path_or_fileobj=run_config_local, ) ) diff --git a/policyengine_us_data/utils/loss.py b/policyengine_us_data/utils/loss.py index 51be118b..bfbf49db 100644 --- a/policyengine_us_data/utils/loss.py +++ b/policyengine_us_data/utils/loss.py @@ -166,9 +166,7 @@ def build_loss_matrix(dataset: type, time_period): continue mask = ( - (agi >= row["AGI lower bound"]) - * (agi < row["AGI upper bound"]) - * filer + (agi >= row["AGI lower bound"]) * (agi < row["AGI upper bound"]) * filer ) > 0 if row["Filing status"] == "Single": @@ -188,12 +186,8 @@ def build_loss_matrix(dataset: type, time_period): if row["Count"]: values = (values > 0).astype(float) - agi_range_label = ( - f"{fmt(row['AGI lower bound'])}-{fmt(row['AGI upper bound'])}" - ) - taxable_label = ( - "taxable" if row["Taxable only"] else "all" + " returns" - ) + agi_range_label = f"{fmt(row['AGI lower bound'])}-{fmt(row['AGI upper bound'])}" + taxable_label = "taxable" if row["Taxable only"] else "all" + " returns" filing_status_label = row["Filing status"] variable_label = row["Variable"].replace("_", " ") @@ -272,9 +266,7 @@ def build_loss_matrix(dataset: type, time_period): for variable_name in CBO_PROGRAMS: label = f"nation/cbo/{variable_name}" - loss_matrix[label] = sim.calculate( - variable_name, map_to="household" - ).values + loss_matrix[label] = sim.calculate(variable_name, map_to="household").values if any(loss_matrix[label].isna()): raise ValueError(f"Missing values for {label}") param_name = CBO_PARAM_NAME_MAP.get(variable_name, variable_name) @@ -314,9 +306,9 @@ def build_loss_matrix(dataset: type, time_period): # National ACA Enrollment (people receiving a PTC) label = "nation/gov/aca_enrollment" - on_ptc = ( - sim.calculate("aca_ptc", map_to="person", period=2025).values > 0 - ).astype(int) + on_ptc = (sim.calculate("aca_ptc", map_to="person", period=2025).values > 0).astype( + int + ) loss_matrix[label] = sim.map_result(on_ptc, "person", "household") ACA_PTC_ENROLLMENT_2024 = 19_743_689 # people enrolled @@ -348,13 +340,9 @@ def build_loss_matrix(dataset: type, time_period): eitc_eligible_children = sim.calculate("eitc_child_count").values eitc = sim.calculate("eitc").values if row["count_children"] < 2: - meets_child_criteria = ( - eitc_eligible_children == row["count_children"] - ) + meets_child_criteria = eitc_eligible_children == row["count_children"] else: - meets_child_criteria = ( - eitc_eligible_children >= row["count_children"] - ) + meets_child_criteria = eitc_eligible_children >= row["count_children"] loss_matrix[returns_label] = sim.map_result( (eitc > 0) * meets_child_criteria, "tax_unit", @@ -408,9 +396,7 @@ def build_loss_matrix(dataset: type, time_period): # Hard-coded totals for variable_name, target in HARD_CODED_TOTALS.items(): label = f"nation/census/{variable_name}" - loss_matrix[label] = sim.calculate( - variable_name, map_to="household" - ).values + loss_matrix[label] = sim.calculate(variable_name, map_to="household").values if any(loss_matrix[label].isna()): raise ValueError(f"Missing values for {label}") targets_array.append(target) @@ -418,8 +404,8 @@ def build_loss_matrix(dataset: type, time_period): # Negative household market income total rough estimate from the IRS SOI PUF market_income = sim.calculate("household_market_income").values - loss_matrix["nation/irs/negative_household_market_income_total"] = ( - market_income * (market_income < 0) + loss_matrix["nation/irs/negative_household_market_income_total"] = market_income * ( + market_income < 0 ) targets_array.append(-138e9) @@ -450,39 +436,27 @@ def build_loss_matrix(dataset: type, time_period): # AGI by SPM threshold totals - spm_threshold_agi = pd.read_csv( - CALIBRATION_FOLDER / "spm_threshold_agi.csv" - ) + spm_threshold_agi = pd.read_csv(CALIBRATION_FOLDER / "spm_threshold_agi.csv") for _, row in spm_threshold_agi.iterrows(): - spm_unit_agi = sim.calculate( - "adjusted_gross_income", map_to="spm_unit" - ).values + spm_unit_agi = sim.calculate("adjusted_gross_income", map_to="spm_unit").values spm_threshold = sim.calculate("spm_unit_spm_threshold").values in_threshold_range = (spm_threshold >= row["lower_spm_threshold"]) * ( spm_threshold < row["upper_spm_threshold"] ) - label = ( - f"nation/census/agi_in_spm_threshold_decile_{int(row['decile'])}" - ) + label = f"nation/census/agi_in_spm_threshold_decile_{int(row['decile'])}" loss_matrix[label] = sim.map_result( in_threshold_range * spm_unit_agi, "spm_unit", "household" ) targets_array.append(row["adjusted_gross_income"]) - label = ( - f"nation/census/count_in_spm_threshold_decile_{int(row['decile'])}" - ) - loss_matrix[label] = sim.map_result( - in_threshold_range, "spm_unit", "household" - ) + label = f"nation/census/count_in_spm_threshold_decile_{int(row['decile'])}" + loss_matrix[label] = sim.map_result(in_threshold_range, "spm_unit", "household") targets_array.append(row["count"]) # Population by state and population under 5 by state - state_population = pd.read_csv( - CALIBRATION_FOLDER / "population_by_state.csv" - ) + state_population = pd.read_csv(CALIBRATION_FOLDER / "population_by_state.csv") for _, row in state_population.iterrows(): in_state = sim.calculate("state_code", map_to="person") == row["state"] @@ -493,9 +467,7 @@ def build_loss_matrix(dataset: type, time_period): under_5 = sim.calculate("age").values < 5 in_state_under_5 = in_state * under_5 label = f"state/census/population_under_5_by_state/{row['state']}" - loss_matrix[label] = sim.map_result( - in_state_under_5, "person", "household" - ) + loss_matrix[label] = sim.map_result(in_state_under_5, "person", "household") targets_array.append(row["population_under_5"]) age = sim.calculate("age").values @@ -519,9 +491,7 @@ def build_loss_matrix(dataset: type, time_period): # SALT tax expenditure targeting - _add_tax_expenditure_targets( - dataset, time_period, sim, loss_matrix, targets_array - ) + _add_tax_expenditure_targets(dataset, time_period, sim, loss_matrix, targets_array) if any(loss_matrix.isna().sum() > 0): raise ValueError("Some targets are missing from the loss matrix") @@ -535,9 +505,7 @@ def build_loss_matrix(dataset: type, time_period): # Overall count by SSN card type label = f"nation/ssa/ssn_card_type_{card_type_str.lower()}_count" - loss_matrix[label] = sim.map_result( - ssn_type_mask, "person", "household" - ) + loss_matrix[label] = sim.map_result(ssn_type_mask, "person", "household") # Target undocumented population by year based on various sources if card_type_str == "NONE": @@ -573,14 +541,11 @@ def build_loss_matrix(dataset: type, time_period): for _, row in spending_by_state.iterrows(): # Households located in this state in_state = ( - sim.calculate("state_code", map_to="household").values - == row["state"] + sim.calculate("state_code", map_to="household").values == row["state"] ) # ACA PTC amounts for every household (2025) - aca_value = sim.calculate( - "aca_ptc", map_to="household", period=2025 - ).values + aca_value = sim.calculate("aca_ptc", map_to="household", period=2025).values # Add a loss-matrix entry and matching target label = f"nation/irs/aca_spending/{row['state'].lower()}" @@ -613,9 +578,7 @@ def build_loss_matrix(dataset: type, time_period): in_state_enrolled = in_state & is_enrolled label = f"state/irs/aca_enrollment/{row['state'].lower()}" - loss_matrix[label] = sim.map_result( - in_state_enrolled, "person", "household" - ) + loss_matrix[label] = sim.map_result(in_state_enrolled, "person", "household") if any(loss_matrix[label].isna()): raise ValueError(f"Missing values for {label}") @@ -632,9 +595,7 @@ def build_loss_matrix(dataset: type, time_period): state_person = sim.calculate("state_code", map_to="person").values # Flag people in households that actually receive medicaid - has_medicaid = sim.calculate( - "medicaid_enrolled", map_to="person", period=2025 - ) + has_medicaid = sim.calculate("medicaid_enrolled", map_to="person", period=2025) is_medicaid_eligible = sim.calculate( "is_medicaid_eligible", map_to="person", period=2025 ).values @@ -646,9 +607,7 @@ def build_loss_matrix(dataset: type, time_period): in_state_enrolled = in_state & is_enrolled label = f"irs/medicaid_enrollment/{row['state'].lower()}" - loss_matrix[label] = sim.map_result( - in_state_enrolled, "person", "household" - ) + loss_matrix[label] = sim.map_result(in_state_enrolled, "person", "household") if any(loss_matrix[label].isna()): raise ValueError(f"Missing values for {label}") @@ -672,9 +631,7 @@ def build_loss_matrix(dataset: type, time_period): age_lower_bound = int(age_range.replace("+", "")) age_upper_bound = np.inf else: - age_lower_bound, age_upper_bound = map( - int, age_range.split("-") - ) + age_lower_bound, age_upper_bound = map(int, age_range.split("-")) age_mask = (age >= age_lower_bound) & (age <= age_upper_bound) label = f"state/census/age/{state}/{age_range}" @@ -745,9 +702,7 @@ def apply(self): simulation.default_calculation_period = time_period # Calculate the baseline and reform income tax values. - income_tax_r = simulation.calculate( - "income_tax", map_to="household" - ).values + income_tax_r = simulation.calculate("income_tax", map_to="household").values # Compute the tax expenditure (TE) values. te_values = income_tax_r - income_tax_b @@ -781,9 +736,7 @@ def _add_agi_state_targets(): + soi_targets["VARIABLE"] + "/" + soi_targets.apply( - lambda r: get_agi_band_label( - r["AGI_LOWER_BOUND"], r["AGI_UPPER_BOUND"] - ), + lambda r: get_agi_band_label(r["AGI_LOWER_BOUND"], r["AGI_UPPER_BOUND"]), axis=1, ) ) @@ -804,9 +757,7 @@ def _add_agi_metric_columns( agi = sim.calculate("adjusted_gross_income").values state = sim.calculate("state_code", map_to="person").values - state = sim.map_result( - state, "person", "tax_unit", how="value_from_first_person" - ) + state = sim.map_result(state, "person", "tax_unit", how="value_from_first_person") for _, r in soi_targets.iterrows(): lower, upper = r.AGI_LOWER_BOUND, r.AGI_UPPER_BOUND @@ -850,13 +801,9 @@ def _add_state_real_estate_taxes(loss_matrix, targets_list, sim): rtol=1e-8, ), "Real estate tax totals do not sum to national target" - targets_list.extend( - real_estate_taxes_targets["real_estate_taxes_bn"].tolist() - ) + targets_list.extend(real_estate_taxes_targets["real_estate_taxes_bn"].tolist()) - real_estate_taxes = sim.calculate( - "real_estate_taxes", map_to="household" - ).values + real_estate_taxes = sim.calculate("real_estate_taxes", map_to="household").values state = sim.calculate("state_code", map_to="household").values for _, r in real_estate_taxes_targets.iterrows(): @@ -879,22 +826,16 @@ def _add_snap_state_targets(sim): ).calibration.gov.cbo._children["snap"] ratio = snap_targets[["Cost"]].sum().values[0] / national_cost_target snap_targets[["CostAdj"]] = snap_targets[["Cost"]] / ratio - assert ( - np.round(snap_targets[["CostAdj"]].sum().values[0]) - == national_cost_target - ) + assert np.round(snap_targets[["CostAdj"]].sum().values[0]) == national_cost_target cost_targets = snap_targets.copy()[["GEO_ID", "CostAdj"]] - cost_targets["target_name"] = ( - cost_targets["GEO_ID"].str[-4:] + "/snap-cost" - ) + cost_targets["target_name"] = cost_targets["GEO_ID"].str[-4:] + "/snap-cost" hh_targets = snap_targets.copy()[["GEO_ID", "Households"]] hh_targets["target_name"] = snap_targets["GEO_ID"].str[-4:] + "/snap-hhs" target_names = ( - cost_targets["target_name"].tolist() - + hh_targets["target_name"].tolist() + cost_targets["target_name"].tolist() + hh_targets["target_name"].tolist() ) target_values = ( cost_targets["CostAdj"].astype(float).tolist() @@ -913,14 +854,12 @@ def _add_snap_metric_columns( snap_targets = pd.read_csv(CALIBRATION_FOLDER / "snap_state.csv") snap_cost = sim.calculate("snap_reported", map_to="household").values - snap_hhs = ( - sim.calculate("snap_reported", map_to="household").values > 0 - ).astype(int) + snap_hhs = (sim.calculate("snap_reported", map_to="household").values > 0).astype( + int + ) state = sim.calculate("state_code", map_to="person").values - state = sim.map_result( - state, "person", "household", how="value_from_first_person" - ) + state = sim.map_result(state, "person", "household", how="value_from_first_person") STATE_ABBR_TO_FIPS["DC"] = 11 state_fips = pd.Series(state).apply(lambda s: STATE_ABBR_TO_FIPS[s]) @@ -939,9 +878,7 @@ def _add_snap_metric_columns( return loss_matrix -def print_reweighting_diagnostics( - optimised_weights, loss_matrix, targets_array, label -): +def print_reweighting_diagnostics(optimised_weights, loss_matrix, targets_array, label): # Convert all inputs to NumPy arrays right at the start optimised_weights_np = ( optimised_weights.numpy() @@ -968,9 +905,7 @@ def print_reweighting_diagnostics( # All subsequent calculations use the guaranteed NumPy versions estimate = optimised_weights_np @ loss_matrix_np - rel_error = ( - ((estimate - targets_array_np) + 1) / (targets_array_np + 1) - ) ** 2 + rel_error = (((estimate - targets_array_np) + 1) / (targets_array_np + 1)) ** 2 within_10_percent_mask = np.abs(estimate - targets_array_np) <= ( 0.10 * np.abs(targets_array_np) ) diff --git a/policyengine_us_data/utils/randomness.py b/policyengine_us_data/utils/randomness.py index eac01522..001dbf2f 100644 --- a/policyengine_us_data/utils/randomness.py +++ b/policyengine_us_data/utils/randomness.py @@ -11,9 +11,7 @@ def _stable_string_hash(s: str) -> np.uint64: Ported from policyengine_core.commons.formulas._stable_string_hash. """ with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", "overflow encountered", RuntimeWarning - ) + warnings.filterwarnings("ignore", "overflow encountered", RuntimeWarning) h = np.uint64(0) for byte in s.encode("utf-8"): h = h * np.uint64(31) + np.uint64(byte) diff --git a/policyengine_us_data/utils/soi.py b/policyengine_us_data/utils/soi.py index d9538add..b9755c30 100644 --- a/policyengine_us_data/utils/soi.py +++ b/policyengine_us_data/utils/soi.py @@ -11,9 +11,7 @@ def pe_to_soi(pe_dataset, year): pe_sim.default_calculation_period = year df = pd.DataFrame() - pe = lambda variable: np.array( - pe_sim.calculate(variable, map_to="tax_unit") - ) + pe = lambda variable: np.array(pe_sim.calculate(variable, map_to="tax_unit")) df["adjusted_gross_income"] = pe("adjusted_gross_income") df["exemption"] = pe("exemptions") @@ -51,12 +49,8 @@ def pe_to_soi(pe_dataset, year): df["total_pension_income"] = pe("pension_income") df["taxable_pension_income"] = pe("taxable_pension_income") df["qualified_dividends"] = pe("qualified_dividend_income") - df["rent_and_royalty_net_income"] = pe("rental_income") * ( - pe("rental_income") > 0 - ) - df["rent_and_royalty_net_losses"] = -pe("rental_income") * ( - pe("rental_income") < 0 - ) + df["rent_and_royalty_net_income"] = pe("rental_income") * (pe("rental_income") > 0) + df["rent_and_royalty_net_losses"] = -pe("rental_income") * (pe("rental_income") < 0) df["total_social_security"] = pe("social_security") df["taxable_social_security"] = pe("taxable_social_security") df["income_tax_before_credits"] = pe("income_tax_before_credits") @@ -176,8 +170,7 @@ def get_soi(year: int) -> pd.DataFrame: pe_name = uprating_map.get(variable) if pe_name in uprating.index: uprating_factors[variable] = ( - uprating.loc[pe_name, year] - / uprating.loc[pe_name, soi.Year.max()] + uprating.loc[pe_name, year] / uprating.loc[pe_name, soi.Year.max()] ) else: uprating_factors[variable] = ( @@ -218,9 +211,7 @@ def compare_soi_replication_to_soi(df, soi): elif fs == "Head of Household": subset = subset[subset.filing_status == "HEAD_OF_HOUSEHOLD"] elif fs == "Married Filing Jointly/Surviving Spouse": - subset = subset[ - subset.filing_status.isin(["JOINT", "SURVIVING_SPOUSE"]) - ] + subset = subset[subset.filing_status.isin(["JOINT", "SURVIVING_SPOUSE"])] elif fs == "Married Filing Separately": subset = subset[subset.filing_status == "SEPARATE"] @@ -258,17 +249,13 @@ def compare_soi_replication_to_soi(df, soi): } ) - soi_replication["Error"] = ( - soi_replication["Value"] - soi_replication["SOI Value"] - ) + soi_replication["Error"] = soi_replication["Value"] - soi_replication["SOI Value"] soi_replication["Absolute error"] = soi_replication["Error"].abs() soi_replication["Relative error"] = ( (soi_replication["Error"] / soi_replication["SOI Value"]) .replace([np.inf, -np.inf], np.nan) .fillna(0) ) - soi_replication["Absolute relative error"] = soi_replication[ - "Relative error" - ].abs() + soi_replication["Absolute relative error"] = soi_replication["Relative error"].abs() return soi_replication diff --git a/policyengine_us_data/utils/spm.py b/policyengine_us_data/utils/spm.py index b2e4538b..ad3c9e9f 100644 --- a/policyengine_us_data/utils/spm.py +++ b/policyengine_us_data/utils/spm.py @@ -44,9 +44,7 @@ def calculate_spm_thresholds_with_geoadj( for i in range(n): tenure_str = TENURE_CODE_MAP.get(int(tenure_codes[i]), "renter") base = base_thresholds[tenure_str] - equiv_scale = spm_equivalence_scale( - int(num_adults[i]), int(num_children[i]) - ) + equiv_scale = spm_equivalence_scale(int(num_adults[i]), int(num_children[i])) thresholds[i] = base * equiv_scale * geoadj[i] return thresholds diff --git a/policyengine_us_data/utils/uprating.py b/policyengine_us_data/utils/uprating.py index 6dd2f89c..41d223b0 100644 --- a/policyengine_us_data/utils/uprating.py +++ b/policyengine_us_data/utils/uprating.py @@ -23,9 +23,7 @@ def create_policyengine_uprating_factors_table(): parameter = system.parameters.get_child(variable.uprating) start_value = parameter(START_YEAR) for year in range(START_YEAR, END_YEAR + 1): - population_growth = population_size(year) / population_size( - START_YEAR - ) + population_growth = population_size(year) / population_size(START_YEAR) variable_names.append(variable.name) years.append(year) growth = parameter(year) / start_value diff --git a/tests/test_h6_reform.py b/tests/test_h6_reform.py index e68ed8db..2acdd8cc 100644 --- a/tests/test_h6_reform.py +++ b/tests/test_h6_reform.py @@ -27,17 +27,13 @@ def calculate_oasdi_thresholds(year: int) -> tuple[int, int]: return oasdi_single, oasdi_joint -def get_swapped_thresholds( - oasdi_threshold: int, hi_threshold: int -) -> tuple[int, int]: +def get_swapped_thresholds(oasdi_threshold: int, hi_threshold: int) -> tuple[int, int]: """ Apply min/max swap to handle threshold crossover. Returns (base_threshold, adjusted_threshold) where base <= adjusted. """ - return min(oasdi_threshold, hi_threshold), max( - oasdi_threshold, hi_threshold - ) + return min(oasdi_threshold, hi_threshold), max(oasdi_threshold, hi_threshold) def needs_crossover_swap(oasdi_threshold: int, hi_threshold: int) -> bool: @@ -145,9 +141,7 @@ def test_single_crossover_starts_2046(self): # 2046+: crossover for year in range(2046, 2054): oasdi_single, _ = calculate_oasdi_thresholds(year) - assert needs_crossover_swap( - oasdi_single, HI_SINGLE - ), f"Year {year}" + assert needs_crossover_swap(oasdi_single, HI_SINGLE), f"Year {year}" class TestH6ThresholdSwapping: @@ -211,9 +205,9 @@ def test_2045_error_analysis(self): assert single_error_swapped == pytest.approx(225) assert joint_error_default == pytest.approx(3_150) - assert joint_error_default / single_error_swapped == pytest.approx( - 14.0 - ), "Swapped rates should have 14x less error" + assert joint_error_default / single_error_swapped == pytest.approx(14.0), ( + "Swapped rates should have 14x less error" + ) def test_swapped_rates_align_with_tax_cut_intent(self): """Swapped rates undertax (not overtax), aligning with reform intent.""" diff --git a/tests/test_no_formula_variables_stored.py b/tests/test_no_formula_variables_stored.py index 9334a5c7..7c7cb0de 100644 --- a/tests/test_no_formula_variables_stored.py +++ b/tests/test_no_formula_variables_stored.py @@ -109,11 +109,7 @@ def test_stored_values_match_computed( computed_total = np.sum(computed.astype(float)) if abs(stored_total) > 0: - pct_diff = ( - abs(stored_total - computed_total) - / abs(stored_total) - * 100 - ) + pct_diff = abs(stored_total - computed_total) / abs(stored_total) * 100 else: pct_diff = 0 @@ -141,23 +137,13 @@ def test_ss_subcomponents_sum_to_computed_total(sim, dataset_path): stored in the dataset sum to the simulation's computed total. """ with h5py.File(dataset_path, "r") as f: - ss_retirement = f["social_security_retirement"]["2024"][...].astype( - float - ) - ss_disability = f["social_security_disability"]["2024"][...].astype( - float - ) - ss_survivors = f["social_security_survivors"]["2024"][...].astype( - float - ) - ss_dependents = f["social_security_dependents"]["2024"][...].astype( - float - ) + ss_retirement = f["social_security_retirement"]["2024"][...].astype(float) + ss_disability = f["social_security_disability"]["2024"][...].astype(float) + ss_survivors = f["social_security_survivors"]["2024"][...].astype(float) + ss_dependents = f["social_security_dependents"]["2024"][...].astype(float) sub_sum = ss_retirement + ss_disability + ss_survivors + ss_dependents - computed_total = np.array(sim.calculate("social_security", 2024)).astype( - float - ) + computed_total = np.array(sim.calculate("social_security", 2024)).astype(float) # Only check records that have any SS income has_ss = computed_total > 0 diff --git a/tests/test_reproducibility.py b/tests/test_reproducibility.py index 1ec097a7..25755f0a 100644 --- a/tests/test_reproducibility.py +++ b/tests/test_reproducibility.py @@ -144,9 +144,9 @@ def test_output_checksums(self): if file_path.exists() and filename != "checksums.txt": with open(file_path, "rb") as f: actual_checksum = hashlib.sha256(f.read()).hexdigest() - assert ( - actual_checksum == expected_checksum - ), f"Checksum mismatch for {filename}" + assert actual_checksum == expected_checksum, ( + f"Checksum mismatch for {filename}" + ) def test_memory_usage(self): """Test that memory usage stays within bounds.""" diff --git a/tests/test_weeks_unemployed.py b/tests/test_weeks_unemployed.py index 18aa4762..d64d8b64 100644 --- a/tests/test_weeks_unemployed.py +++ b/tests/test_weeks_unemployed.py @@ -21,9 +21,9 @@ def test_lkweeks_in_person_columns(self): # Check for correct variable assert '"LKWEEKS"' in content, "LKWEEKS should be in PERSON_COLUMNS" - assert ( - '"WKSUNEM"' not in content - ), "WKSUNEM should not be in PERSON_COLUMNS (Census uses LKWEEKS)" + assert '"WKSUNEM"' not in content, ( + "WKSUNEM should not be in PERSON_COLUMNS (Census uses LKWEEKS)" + ) def test_cps_uses_lkweeks(self): """Test that cps.py uses LKWEEKS, not WKSUNEM.""" diff --git a/validation/benefit_validation.py b/validation/benefit_validation.py index d614ae03..cf468972 100644 --- a/validation/benefit_validation.py +++ b/validation/benefit_validation.py @@ -50,9 +50,7 @@ def analyze_benefit_underreporting(): # Participation participants = (benefit > 0).sum() - weighted_participants = ( - (benefit > 0) * weight - ).sum() / 1e6 # millions + weighted_participants = ((benefit > 0) * weight).sum() / 1e6 # millions # Underreporting factor underreporting = info["admin_total"] / total if total > 0 else np.inf @@ -168,9 +166,7 @@ def earnings_reform(parameters): earnings_change = earnings * pct_increase / 100 net_change = reformed_net - original_net - emtr = np.where( - earnings_change > 0, 1 - (net_change / earnings_change), 0 - ) + emtr = np.where(earnings_change > 0, 1 - (net_change / earnings_change), 0) # Focus on sample sample_emtr = emtr[sample] @@ -254,9 +250,7 @@ def analyze_aca_subsidies(): total_ptc = (ptc[mask] * weight[mask]).sum() / 1e9 recipients = ((ptc > 0) & mask).sum() weighted_recipients = (((ptc > 0) & mask) * weight).sum() / 1e6 - mean_ptc = ( - ptc[(ptc > 0) & mask].mean() if ((ptc > 0) & mask).any() else 0 - ) + mean_ptc = ptc[(ptc > 0) & mask].mean() if ((ptc > 0) & mask).any() else 0 results.append( { @@ -307,9 +301,7 @@ def generate_benefit_validation_report(): print("\n\n4. Top 10 States by SNAP Benefits") print("-" * 40) state_df = validate_state_benefits() - top_states = state_df.nlargest(10, "snap_billions")[ - ["state_code", "snap_billions"] - ] + top_states = state_df.nlargest(10, "snap_billions")[["state_code", "snap_billions"]] print(top_states.to_string(index=False)) # ACA analysis @@ -319,9 +311,7 @@ def generate_benefit_validation_report(): print(aca_df.to_string(index=False)) # Save results - underreporting_df.to_csv( - "validation/benefit_underreporting.csv", index=False - ) + underreporting_df.to_csv("validation/benefit_underreporting.csv", index=False) interactions_df.to_csv("validation/program_interactions.csv", index=False) emtr_df.to_csv("validation/effective_marginal_tax_rates.csv", index=False) state_df.to_csv("validation/state_benefit_totals.csv", index=False) diff --git a/validation/generate_qrf_statistics.py b/validation/generate_qrf_statistics.py index 4a026dea..4015fe1e 100644 --- a/validation/generate_qrf_statistics.py +++ b/validation/generate_qrf_statistics.py @@ -222,18 +222,14 @@ print(support_df.round(3).to_string()) print("\nSummary:") -print( - f"- Average overlap coefficient: {support_df['overlap_coefficient'].mean():.3f}" -) +print(f"- Average overlap coefficient: {support_df['overlap_coefficient'].mean():.3f}") print( f"- All overlap coefficients > 0.85: {(support_df['overlap_coefficient'] > 0.85).all()}" ) print( f"- Variables with SMD > 0.25: {(support_df['standardized_mean_diff'] > 0.25).sum()}" ) -print( - f"- All SMDs < 0.25: {(support_df['standardized_mean_diff'] < 0.25).all()}" -) +print(f"- All SMDs < 0.25: {(support_df['standardized_mean_diff'] < 0.25).all()}") print( f"- Variables with significant KS test (p<0.05): {(support_df['ks_pvalue'] < 0.05).sum()}" ) @@ -279,9 +275,7 @@ print( f"- All correlation differences < 0.05: {(joint_df['correlation_diff'] < 0.05).all()}" ) -print( - f"- Average correlation difference: {joint_df['correlation_diff'].mean():.3f}" -) +print(f"- Average correlation difference: {joint_df['correlation_diff'].mean():.3f}") # Save all results print("\n\nSAVING RESULTS...") @@ -294,9 +288,7 @@ ) accuracy_df.to_csv("validation/outputs/qrf_accuracy_metrics.csv") -print( - "✓ Saved accuracy metrics to validation/outputs/qrf_accuracy_metrics.csv" -) +print("✓ Saved accuracy metrics to validation/outputs/qrf_accuracy_metrics.csv") joint_df.to_csv("validation/outputs/joint_distribution_tests.csv", index=False) print( @@ -309,9 +301,7 @@ f.write("=" * 40 + "\n\n") for var, r2 in variance_explained.items(): f.write(f"{var.replace('_', ' ').title()}: {r2 * 100:.0f}%\n") -print( - "✓ Saved variance explained to validation/outputs/variance_explained.txt" -) +print("✓ Saved variance explained to validation/outputs/variance_explained.txt") # Create summary report with open("validation/outputs/qrf_diagnostics_summary.txt", "w") as f: @@ -327,12 +317,8 @@ f.write( f"All overlap coefficients > 0.85: {(support_df['overlap_coefficient'] > 0.85).all()}\n" ) - f.write( - f"All SMDs < 0.25: {(support_df['standardized_mean_diff'] < 0.25).all()}\n" - ) - f.write( - f"All KS tests p > 0.05: {(support_df['ks_pvalue'] > 0.05).all()}\n\n" - ) + f.write(f"All SMDs < 0.25: {(support_df['standardized_mean_diff'] < 0.25).all()}\n") + f.write(f"All KS tests p > 0.05: {(support_df['ks_pvalue'] > 0.05).all()}\n\n") f.write("2. VARIANCE EXPLAINED\n") f.write("-" * 40 + "\n") @@ -361,9 +347,7 @@ ) f.write("\n" + "=" * 60 + "\n") - f.write( - "These statistics demonstrate that the QRF methodology successfully:\n" - ) + f.write("These statistics demonstrate that the QRF methodology successfully:\n") f.write("- Maintains strong common support between datasets\n") f.write("- Achieves high predictive accuracy for imputation\n") f.write("- Preserves joint distributions of variables\n") diff --git a/validation/qrf_diagnostics.py b/validation/qrf_diagnostics.py index 4e572916..d22f883c 100644 --- a/validation/qrf_diagnostics.py +++ b/validation/qrf_diagnostics.py @@ -28,9 +28,7 @@ def analyze_common_support(cps_data, puf_data, predictors): # Overlap coefficient (Weitzman 1970) # OVL = sum(min(f(x), g(x))) where f,g are densities - bins = np.histogram_bin_edges( - np.concatenate([cps_dist, puf_dist]), bins=50 - ) + bins = np.histogram_bin_edges(np.concatenate([cps_dist, puf_dist]), bins=50) cps_hist, _ = np.histogram(cps_dist, bins=bins, density=True) puf_hist, _ = np.histogram(puf_dist, bins=bins, density=True) @@ -81,9 +79,7 @@ def validate_qrf_accuracy(puf_data, predictors, target_vars, n_estimators=100): ) # Fit QRF - qrf = RandomForestQuantileRegressor( - n_estimators=n_estimators, random_state=42 - ) + qrf = RandomForestQuantileRegressor(n_estimators=n_estimators, random_state=42) qrf.fit(X_train, y_train) # Predictions at multiple quantiles @@ -124,9 +120,7 @@ def validate_qrf_accuracy(puf_data, predictors, target_vars, n_estimators=100): "qrf_rmse": rmse, "hotdeck_mae": hotdeck_mae, "linear_mae": lr_mae, - "qrf_improvement_vs_hotdeck": (hotdeck_mae - mae) - / hotdeck_mae - * 100, + "qrf_improvement_vs_hotdeck": (hotdeck_mae - mae) / hotdeck_mae * 100, "qrf_improvement_vs_linear": (lr_mae - mae) / lr_mae * 100, "coverage_90pct": coverage_90, "coverage_50pct": coverage_50, @@ -135,9 +129,7 @@ def validate_qrf_accuracy(puf_data, predictors, target_vars, n_estimators=100): return pd.DataFrame(results).T -def test_joint_distribution_preservation( - original_data, imputed_data, var_pairs -): +def test_joint_distribution_preservation(original_data, imputed_data, var_pairs): """Test whether joint distributions are preserved in imputation.""" results = [] @@ -159,12 +151,12 @@ def test_joint_distribution_preservation( # Joint distribution test (2D KS test approximation) # Using average of marginal KS statistics - ks1 = stats.ks_2samp( - original_data[var1].dropna(), imputed_data[var1].dropna() - )[0] - ks2 = stats.ks_2samp( - original_data[var2].dropna(), imputed_data[var2].dropna() - )[0] + ks1 = stats.ks_2samp(original_data[var1].dropna(), imputed_data[var1].dropna())[ + 0 + ] + ks2 = stats.ks_2samp(original_data[var2].dropna(), imputed_data[var2].dropna())[ + 0 + ] joint_ks = (ks1 + ks2) / 2 results.append( @@ -281,9 +273,7 @@ def generate_qrf_diagnostic_report(cps_data, puf_data, imputed_data): print( f"- Average QRF improvement vs linear: {accuracy_df['qrf_improvement_vs_linear'].mean():.1f}%" ) - print( - f"- Average 90% coverage: {accuracy_df['coverage_90pct'].mean():.3f}" - ) + print(f"- Average 90% coverage: {accuracy_df['coverage_90pct'].mean():.3f}") # Joint distribution preservation print("\n\n3. Joint Distribution Preservation") @@ -295,16 +285,12 @@ def generate_qrf_diagnostic_report(cps_data, puf_data, imputed_data): ("pension_income", "social_security"), ] - joint_df = test_joint_distribution_preservation( - puf_data, imputed_data, var_pairs - ) + joint_df = test_joint_distribution_preservation(puf_data, imputed_data, var_pairs) print(joint_df.to_string(index=False)) # Create diagnostic plots create_diagnostic_plots(cps_data, puf_data, predictors) - print( - "\n\nDiagnostic plots saved to validation/common_support_diagnostics.png" - ) + print("\n\nDiagnostic plots saved to validation/common_support_diagnostics.png") # Save results support_df.to_csv("validation/common_support_analysis.csv") diff --git a/validation/tax_policy_validation.py b/validation/tax_policy_validation.py index c7c4f600..9e04982f 100644 --- a/validation/tax_policy_validation.py +++ b/validation/tax_policy_validation.py @@ -101,9 +101,7 @@ def analyze_high_income_taxpayers(): for threshold in thresholds: count = (weights[agi >= threshold]).sum() pct_returns = count / weights.sum() * 100 - total_agi = ( - agi[agi >= threshold] * weights[agi >= threshold] - ).sum() / 1e9 + total_agi = (agi[agi >= threshold] * weights[agi >= threshold]).sum() / 1e9 results.append( { @@ -135,9 +133,7 @@ def validate_state_revenues(): results.append({"state_code": state, "revenue_billions": total}) - return pd.DataFrame(results).sort_values( - "revenue_billions", ascending=False - ) + return pd.DataFrame(results).sort_values("revenue_billions", ascending=False) def generate_validation_report(): diff --git a/validation/validate_retirement_imputation.py b/validation/validate_retirement_imputation.py index 6a11eafd..065a8294 100644 --- a/validation/validate_retirement_imputation.py +++ b/validation/validate_retirement_imputation.py @@ -54,12 +54,8 @@ def validate_constraints(sim) -> list: issues = [] year = 2024 - emp_income = sim.calculate( - "employment_income", year, map_to="person" - ).values - se_income = sim.calculate( - "self_employment_income", year, map_to="person" - ).values + emp_income = sim.calculate("employment_income", year, map_to="person").values + se_income = sim.calculate("self_employment_income", year, map_to="person").values age = sim.calculate("age", year, map_to="person").values catch_up = age >= 50 @@ -79,9 +75,7 @@ def validate_constraints(sim) -> list: n_over_cap = (vals > max_401k + 1).sum() if n_over_cap > 0: - issues.append( - f"FAIL: {var} has {n_over_cap} values exceeding 401k cap" - ) + issues.append(f"FAIL: {var} has {n_over_cap} values exceeding 401k cap") zero_wage = emp_income == 0 n_nonzero_no_wage = (vals[zero_wage] > 0).sum() @@ -110,9 +104,7 @@ def validate_constraints(sim) -> list: n_over_cap = (vals > max_ira + 1).sum() if n_over_cap > 0: - issues.append( - f"FAIL: {var} has {n_over_cap} values exceeding IRA cap" - ) + issues.append(f"FAIL: {var} has {n_over_cap} values exceeding IRA cap") # SE pension constraint var = "self_employed_pension_contributions" @@ -141,9 +133,7 @@ def validate_aggregates(sim) -> list: weight = sim.calculate("person_weight", year).values - logger.info( - "\n%-45s %15s %15s %10s", "Variable", "Weighted Sum", "Target", "Ratio" - ) + logger.info("\n%-45s %15s %15s %10s", "Variable", "Weighted Sum", "Target", "Ratio") logger.info("-" * 90) for var, target in TARGETS.items(): From 514d05d4f6908b80b2e295a7c1acae0c946bccac Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 16 Mar 2026 20:55:06 +0100 Subject: [PATCH 8/8] Apply ruff formatting Co-Authored-By: Claude Opus 4.6 --- policyengine_us_data/db/etl_age.py | 1 - policyengine_us_data/db/etl_irs_soi.py | 1 - policyengine_us_data/db/etl_medicaid.py | 1 - policyengine_us_data/db/etl_snap.py | 1 - .../tests/test_datasets/test_enhanced_cps.py | 2 -- .../test_datasets/test_sparse_enhanced_cps.py | 4 ---- policyengine_us_data/utils/hdfstore.py | 16 ++++------------ policyengine_us_data/utils/huggingface.py | 1 - 8 files changed, 4 insertions(+), 23 deletions(-) diff --git a/policyengine_us_data/db/etl_age.py b/policyengine_us_data/db/etl_age.py index db5e54da..658b76f0 100644 --- a/policyengine_us_data/db/etl_age.py +++ b/policyengine_us_data/db/etl_age.py @@ -92,7 +92,6 @@ def transform_age_data(age_data, docs): def load_age_data(df_long, geo, year): - # Quick data quality check before loading ---- if geo == "National": assert len(set(df_long.ucgid_str)) == 1 diff --git a/policyengine_us_data/db/etl_irs_soi.py b/policyengine_us_data/db/etl_irs_soi.py index f2b17795..b5999e48 100644 --- a/policyengine_us_data/db/etl_irs_soi.py +++ b/policyengine_us_data/db/etl_irs_soi.py @@ -238,7 +238,6 @@ def extract_soi_data() -> pd.DataFrame: def transform_soi_data(raw_df): - TARGETS = [ dict(code="59661", name="eitc", breakdown=("eitc_child_count", 0)), dict(code="59662", name="eitc", breakdown=("eitc_child_count", 1)), diff --git a/policyengine_us_data/db/etl_medicaid.py b/policyengine_us_data/db/etl_medicaid.py index 2c467799..30bae90a 100644 --- a/policyengine_us_data/db/etl_medicaid.py +++ b/policyengine_us_data/db/etl_medicaid.py @@ -150,7 +150,6 @@ def transform_survey_medicaid_data(cd_survey_df): def load_medicaid_data(long_state, long_cd, year): - DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" engine = create_engine(DATABASE_URL) diff --git a/policyengine_us_data/db/etl_snap.py b/policyengine_us_data/db/etl_snap.py index dc5975a4..d21260d1 100644 --- a/policyengine_us_data/db/etl_snap.py +++ b/policyengine_us_data/db/etl_snap.py @@ -153,7 +153,6 @@ def transform_survey_snap_data(raw_df): def load_administrative_snap_data(df_states, year): - DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" engine = create_engine(DATABASE_URL) diff --git a/policyengine_us_data/tests/test_datasets/test_enhanced_cps.py b/policyengine_us_data/tests/test_datasets/test_enhanced_cps.py index 298de5a4..5ebad600 100644 --- a/policyengine_us_data/tests/test_datasets/test_enhanced_cps.py +++ b/policyengine_us_data/tests/test_datasets/test_enhanced_cps.py @@ -145,7 +145,6 @@ def test_undocumented_matches_ssn_none(): def test_aca_calibration(): - import pandas as pd from pathlib import Path from policyengine_us import Microsimulation @@ -231,7 +230,6 @@ def test_immigration_status_diversity(): def test_medicaid_calibration(): - import pandas as pd from pathlib import Path from policyengine_us import Microsimulation diff --git a/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py b/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py index a7ee941b..0314d8e4 100644 --- a/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py +++ b/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py @@ -123,7 +123,6 @@ def test_sparse_ecps_replicates_jct_tax_expenditures(): def deprecated_test_sparse_ecps_replicates_jct_tax_expenditures_full(sim): - # JCT tax expenditure targets EXPENDITURE_TARGETS = { "salt_deduction": 21.247e9, @@ -158,7 +157,6 @@ def apply(self): def test_sparse_ssn_card_type_none_target(sim): - TARGET_COUNT = 13e6 TOLERANCE = 0.2 # Allow 20% error @@ -176,7 +174,6 @@ def test_sparse_ssn_card_type_none_target(sim): def test_sparse_aca_calibration(sim): - TARGETS_PATH = Path( "policyengine_us_data/storage/calibration_targets/aca_spending_and_enrollment_2024.csv" ) @@ -210,7 +207,6 @@ def test_sparse_aca_calibration(sim): def test_sparse_medicaid_calibration(sim): - TARGETS_PATH = Path( "policyengine_us_data/storage/calibration_targets/medicaid_enrollment_2024.csv" ) diff --git a/policyengine_us_data/utils/hdfstore.py b/policyengine_us_data/utils/hdfstore.py index 9a43a0ef..bf24b8fa 100644 --- a/policyengine_us_data/utils/hdfstore.py +++ b/policyengine_us_data/utils/hdfstore.py @@ -54,9 +54,7 @@ def split_data_into_entity_dfs( cols = {} for var_name in entity_vars[entity]: periods = data[var_name] - tp_key = ( - time_period if time_period in periods else str(time_period) - ) + tp_key = time_period if time_period in periods else str(time_period) if tp_key not in periods: continue arr = periods[tp_key] @@ -69,11 +67,7 @@ def split_data_into_entity_dfs( ref_col = f"person_{ref_entity}_id" if ref_col in data: periods = data[ref_col] - tp_key = ( - time_period - if time_period in periods - else str(time_period) - ) + tp_key = time_period if time_period in periods else str(time_period) if tp_key in periods: cols[ref_col] = periods[tp_key] @@ -110,9 +104,7 @@ def build_uprating_manifest( ) uprating = "" if var_name in system.variables: - uprating = ( - getattr(system.variables[var_name], "uprating", None) or "" - ) + uprating = getattr(system.variables[var_name], "uprating", None) or "" records.append( { "variable": var_name, @@ -168,7 +160,7 @@ def save_hdfstore( ) for entity_name, df in entity_dfs.items(): - print(f" {entity_name}: {len(df):,} rows, " f"{len(df.columns)} cols") + print(f" {entity_name}: {len(df):,} rows, {len(df.columns)} cols") print(f" manifest: {len(manifest_df)} variables") print("HDFStore saved successfully!") diff --git a/policyengine_us_data/utils/huggingface.py b/policyengine_us_data/utils/huggingface.py index 9b1e48cb..20f96d0d 100644 --- a/policyengine_us_data/utils/huggingface.py +++ b/policyengine_us_data/utils/huggingface.py @@ -11,7 +11,6 @@ def download(repo: str, repo_filename: str, local_folder: str, version: str = None): - hf_hub_download( repo_id=repo, repo_type="model",