diff --git a/changelog.d/report-run-timestamps.fixed.md b/changelog.d/report-run-timestamps.fixed.md new file mode 100644 index 000000000..8e1ae31d0 --- /dev/null +++ b/changelog.d/report-run-timestamps.fixed.md @@ -0,0 +1 @@ +Expose report output run timestamps for report responses. diff --git a/policyengine_api/endpoints/household.py b/policyengine_api/endpoints/household.py index edd647906..92663ac42 100644 --- a/policyengine_api/endpoints/household.py +++ b/policyengine_api/endpoints/household.py @@ -1,24 +1,23 @@ -from policyengine_api.country import ( - COUNTRIES, -) from policyengine_api.data import database, local_database import json from flask import Response, request -from policyengine_api.utils import hash_object from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS -import sqlalchemy.exc -from policyengine_api.country import COUNTRIES -import json import logging from datetime import date from policyengine_api.utils.payload_validators import validate_country -def add_yearly_variables(household, country_id): +def get_countries(): + from policyengine_api.country import COUNTRIES + + return COUNTRIES + + +def add_yearly_variables(household, country_id, countries=None): """ Add yearly variables to a household dict before enqueueing calculation """ - metadata = COUNTRIES.get(country_id).metadata + metadata = (countries or get_countries()).get(country_id).metadata variables = metadata["variables"] entities = metadata["entities"] @@ -35,8 +34,8 @@ def add_yearly_variables(household, country_id): possible_entities = household[entity_plural].keys() for entity in possible_entities: if ( - not variables[variable]["name"] - in household[entity_plural][entity] + variables[variable]["name"] + not in household[entity_plural][entity] ): if variables[variable]["isInputVariable"]: household[entity_plural][entity][ @@ -85,7 +84,7 @@ def get_household_under_policy(country_id: str, household_id: str, policy_id: st # Look in computed_households to see if already computed row = local_database.query( - f"SELECT * FROM computed_household WHERE household_id = ? AND policy_id = ? AND api_version = ?", + "SELECT * FROM computed_household WHERE household_id = ? AND policy_id = ? AND api_version = ?", (household_id, policy_id, api_version), ).fetchone() @@ -109,7 +108,7 @@ def get_household_under_policy(country_id: str, household_id: str, policy_id: st # Retrieve from the household table row = database.query( - f"SELECT * FROM household WHERE id = ? AND country_id = ?", + "SELECT * FROM household WHERE id = ? AND country_id = ?", (household_id, country_id), ).fetchone() @@ -135,7 +134,7 @@ def get_household_under_policy(country_id: str, household_id: str, policy_id: st # Retrieve from the policy table row = database.query( - f"SELECT * FROM policy WHERE id = ? AND country_id = ?", + "SELECT * FROM policy WHERE id = ? AND country_id = ?", (policy_id, country_id), ).fetchone() @@ -153,7 +152,7 @@ def get_household_under_policy(country_id: str, household_id: str, policy_id: st mimetype="application/json", ) - country = COUNTRIES.get(country_id) + country = get_countries().get(country_id) try: result = country.calculate( @@ -178,7 +177,7 @@ def get_household_under_policy(country_id: str, household_id: str, policy_id: st try: local_database.query( - f"INSERT INTO computed_household (country_id, household_id, policy_id, computed_household_json, api_version) VALUES (?, ?, ?, ?, ?)", + "INSERT INTO computed_household (country_id, household_id, policy_id, computed_household_json, api_version) VALUES (?, ?, ?, ?, ?)", ( country_id, household_id, @@ -190,7 +189,7 @@ def get_household_under_policy(country_id: str, household_id: str, policy_id: st except Exception: # Update the result if it already exists local_database.query( - f"UPDATE computed_household SET computed_household_json = ? WHERE country_id = ? AND household_id = ? AND policy_id = ?", + "UPDATE computed_household SET computed_household_json = ? WHERE country_id = ? AND household_id = ? AND policy_id = ?", (json.dumps(result), country_id, household_id, policy_id), ) @@ -217,7 +216,7 @@ def get_calculate(country_id: str, add_missing: bool = False) -> dict: # Add in any missing yearly variables to household_json household_json = add_yearly_variables(household_json, country_id) - country = COUNTRIES.get(country_id) + country = get_countries().get(country_id) try: result = country.calculate(household_json, policy_json) diff --git a/policyengine_api/openapi_spec.yaml b/policyengine_api/openapi_spec.yaml index 77daadc9e..c8daa82ae 100644 --- a/policyengine_api/openapi_spec.yaml +++ b/policyengine_api/openapi_spec.yaml @@ -104,7 +104,7 @@ paths: data: type: object responses: - 201: + "201": description: OK content: application/json: @@ -151,7 +151,7 @@ paths: schema: type: integer responses: - 200: + "200": description: The policy record. content: application/json: @@ -219,7 +219,7 @@ paths: schema: type: string responses: - 200: + "200": description: The search results. content: application/json: @@ -560,6 +560,177 @@ paths: type: string message: type: string + /{country_id}/report: + post: + summary: Create a report output + operationId: create_report_output + description: Create or retrieve a report output for the provided simulations and year. + parameters: + - name: country_id + in: path + description: The country ID. + required: true + schema: + type: string + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + simulation_1_id: + type: integer + simulation_2_id: + type: integer + nullable: true + year: + type: string + responses: + "200": + description: Existing report output. + content: + application/json: + schema: + type: object + properties: + status: + type: string + message: + type: string + nullable: true + result: + type: object + properties: + requested_at: + type: string + nullable: true + started_at: + type: string + nullable: true + finished_at: + type: string + nullable: true + "201": + description: Created report output. + content: + application/json: + schema: + type: object + properties: + status: + type: string + message: + type: string + nullable: true + result: + type: object + properties: + requested_at: + type: string + nullable: true + started_at: + type: string + nullable: true + finished_at: + type: string + nullable: true + patch: + summary: Update a report output + operationId: update_report_output + description: Update a report output status, result, or error message. + parameters: + - name: country_id + in: path + description: The country ID. + required: true + schema: + type: string + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + id: + type: integer + status: + type: string + output: + type: object + nullable: true + error_message: + type: string + nullable: true + responses: + "200": + description: Updated report output. + content: + application/json: + schema: + type: object + properties: + status: + type: string + message: + type: string + nullable: true + result: + type: object + properties: + requested_at: + type: string + nullable: true + started_at: + type: string + nullable: true + finished_at: + type: string + nullable: true + /{country_id}/report/{report_id}: + get: + summary: Get a report output + operationId: get_report_output + description: Get a report output by ID. Timestamp fields are projected from the selected base report run. + parameters: + - name: country_id + in: path + description: The country ID. + required: true + schema: + type: string + - name: report_id + in: path + description: The report output ID. + required: true + schema: + type: integer + responses: + "200": + description: Report output. + content: + application/json: + schema: + type: object + properties: + status: + type: string + message: + type: string + nullable: true + result: + type: object + properties: + requested_at: + type: string + nullable: true + started_at: + type: string + nullable: true + finished_at: + type: string + nullable: true /{country_id}/economy/{policy_id}/over/{baseline_policy_id}: get: summary: Calculate the economic impact of a policy diff --git a/policyengine_api/routes/report_output_routes.py b/policyengine_api/routes/report_output_routes.py index 1100faf97..48a2ac43a 100644 --- a/policyengine_api/routes/report_output_routes.py +++ b/policyengine_api/routes/report_output_routes.py @@ -118,6 +118,11 @@ def get_report_output(country_id: str, report_id: int) -> Response: """ Get a report output record by ID. + The response result may include requested_at, started_at, and finished_at + values projected from the selected report_output_runs row. Those fields are + base report execution metadata, not user-specific user-report association + last-run metadata. + Args: country_id (str): The country ID. report_id (int): The report output ID. @@ -155,7 +160,7 @@ def update_report_output(country_id: str) -> Response: Request body can contain: - id (int): The report output ID. - - status (str): The new status ('complete' or 'error') + - status (str): The new status ('pending', 'running', 'complete', or 'error') - output (dict): The result output (for complete status) - api_version (str): The API version of the report - error_message (str): The error message (for error status) @@ -173,19 +178,23 @@ def update_report_output(country_id: str) -> Response: print(f"Updating report #{report_id} for country {country_id}") # Validate status if provided - if status is not None and status not in ["pending", "complete", "error"]: - raise BadRequest("status must be 'pending', 'complete', or 'error'") + if status is not None and status not in [ + "pending", + "running", + "complete", + "error", + ]: + raise BadRequest("status must be 'pending', 'running', 'complete', or 'error'") # Validate that complete status has output if status == "complete" and output is None: raise BadRequest("output is required when status is 'complete'") try: - # First check if the report output exists - existing_report = report_output_service.get_stored_report_output( - country_id, report_id - ) - if existing_report is None: + # First check if the report output exists without running pointer sync: + # syncing a completed parent before this mutation can clear an active + # pending rerun that this PATCH is about to mark as running. + if not report_output_service.report_output_exists(country_id, report_id): raise NotFound(f"Report #{report_id} not found.") # Update the report output diff --git a/policyengine_api/services/report_output_service.py b/policyengine_api/services/report_output_service.py index c462257c8..38b5704fa 100644 --- a/policyengine_api/services/report_output_service.py +++ b/policyengine_api/services/report_output_service.py @@ -1,4 +1,5 @@ import uuid +from datetime import datetime, timezone from sqlalchemy.engine.row import Row @@ -12,6 +13,8 @@ from policyengine_api.services.run_sync_utils import ( determine_parent_pointers, parse_json_field, + run_matches_report_result, + select_display_report_run, serialize_json_field, ) from policyengine_api.services.simulation_service import SimulationService @@ -25,6 +28,48 @@ def __init__(self): def _lock_clause(self) -> str: return "" if database.local else " FOR UPDATE" + def _utc_timestamp(self) -> str: + return datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") + + def _format_run_timestamp(self, value) -> str | None: + if value is None: + return None + + if isinstance(value, datetime): + timestamp = value + if timestamp.tzinfo is None: + timestamp = timestamp.replace(tzinfo=timezone.utc) + return ( + timestamp.astimezone(timezone.utc) + .replace(microsecond=0) + .isoformat() + .replace("+00:00", "Z") + ) + + timestamp = str(value).strip() + if not timestamp: + return None + + normalized = timestamp.replace(" ", "T", 1) + parseable_timestamp = ( + f"{normalized[:-1]}+00:00" if normalized.endswith("Z") else normalized + ) + try: + parsed = datetime.fromisoformat(parseable_timestamp) + except ValueError: + if "T" in normalized: + return normalized if normalized.endswith("Z") else f"{normalized}Z" + return f"{normalized}Z" + + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) + return ( + parsed.astimezone(timezone.utc) + .replace(microsecond=0) + .isoformat() + .replace("+00:00", "Z") + ) + def _get_report_output_row( self, report_output_id: int, @@ -137,12 +182,84 @@ def _select_mutable_run( self, report_output: dict, runs_descending: list[dict] ) -> dict | None: active_run_id = report_output.get("active_run_id") + if report_output["status"] == "running": + if active_run_id is not None: + for run in runs_descending: + if run["id"] == active_run_id and run["status"] in ( + "pending", + "running", + ): + return run + for run in runs_descending: + if run["status"] in ("pending", "running"): + return run + return None if active_run_id is not None: for run in runs_descending: if run["id"] == active_run_id: return run return runs_descending[0] if runs_descending else None + def _has_mutable_running_run(self, report_output: dict, *, queryer=None) -> bool: + runs_descending = self._list_report_runs_descending( + report_output["id"], queryer=queryer + ) + if not runs_descending: + return True + + active_run_id = report_output.get("active_run_id") + if active_run_id is not None: + for run in runs_descending: + if run["id"] == active_run_id: + return run["status"] in ("pending", "running") + return False + + return any(run["status"] in ("pending", "running") for run in runs_descending) + + def _run_needs_timestamp_sync(self, run: dict, status: str) -> bool: + if run.get("requested_at") is None: + return True + if status in ("complete", "error"): + return run.get("started_at") is None or run.get("finished_at") is None + if status == "running": + return run.get("started_at") is None or run.get("finished_at") is not None + return run.get("started_at") is not None or run.get("finished_at") is not None + + def _with_display_run_timestamps( + self, report_output: dict, *, queryer=None + ) -> dict: + """ + Overlay selected run timestamps onto the legacy report response shape. + + This is a response-compatibility bridge for app-v2 while report output + reads still return a report_outputs row. The authoritative timestamp + values live on report_output_runs; this helper chooses the display run, + formats its requested/started/finished timestamps, and returns an + enriched copy of the report output dict. It intentionally does not + mutate database state. + + These timestamps describe the selected base report execution. They are + not user-report association metadata and should not be treated as a + user-specific "last run" value. + + TODO: When report output reads are cut over to canonical run-backed + resolution, move this projection into the final response serializer + instead of keeping it as an ad hoc enrichment helper. + """ + runs_descending = self._list_report_runs_descending( + report_output["id"], queryer=queryer + ) + display_run = select_display_report_run(report_output, runs_descending) + if display_run is None: + return report_output + + enriched_report_output = dict(report_output) + for field in ("requested_at", "started_at", "finished_at"): + enriched_report_output[field] = self._format_run_timestamp( + display_run.get(field) + ) + return enriched_report_output + def _derive_report_country_package_version( self, simulation_1: dict | None, @@ -281,6 +398,12 @@ def _insert_bootstrap_report_run( report_spec: ReportSpec | None, version_manifest: dict[str, str | None], ) -> None: + requested_at = self._utc_timestamp() + is_terminal = report_output["status"] in ("complete", "error") + has_started = report_output["status"] in ("running", "complete", "error") + started_at = requested_at if has_started else None + finished_at = requested_at if is_terminal else None + tx.query( """ INSERT INTO report_output_runs ( @@ -300,9 +423,9 @@ def _insert_bootstrap_report_run( serialize_json_field(report_output.get("output")), report_output.get("error_message"), "initial", - None, - None, - None, + requested_at, + started_at, + finished_at, None, (report_spec.model_dump_json() if report_spec is not None else None), version_manifest["country_package_version"], @@ -324,11 +447,41 @@ def _update_report_run_in_transaction( report_output: dict, report_spec: ReportSpec | None, version_manifest: dict[str, str | None], + preserve_terminal_finished_at: bool = False, ) -> None: + fallback_timestamp = self._utc_timestamp() + timestamp_updates = [ + "requested_at = COALESCE(requested_at, started_at, finished_at, ?)" + ] + timestamp_values = [fallback_timestamp] + if report_output["status"] in ("complete", "error"): + finished_at = self._utc_timestamp() + timestamp_updates.append( + "started_at = COALESCE(started_at, finished_at, requested_at, ?)" + ) + timestamp_values.append(finished_at) + if preserve_terminal_finished_at: + timestamp_updates.append("finished_at = COALESCE(finished_at, ?)") + else: + timestamp_updates.append("finished_at = ?") + timestamp_values.append(finished_at) + elif report_output["status"] == "running": + started_at = self._utc_timestamp() + timestamp_updates.extend( + [ + "started_at = COALESCE(started_at, requested_at, ?)", + "finished_at = NULL", + ] + ) + timestamp_values.append(started_at) + else: + timestamp_updates.extend(["started_at = NULL", "finished_at = NULL"]) + tx.query( - """ + f""" UPDATE report_output_runs SET status = ?, output = ?, error_message = ?, + {", ".join(timestamp_updates)}, report_spec_snapshot_json = ?, country_package_version = ?, policyengine_version = ?, data_version = ?, runtime_app_name = ?, report_cache_version = ?, simulation_cache_version = ?, @@ -340,6 +493,7 @@ def _update_report_run_in_transaction( report_output["status"], serialize_json_field(report_output.get("output")), report_output.get("error_message"), + *timestamp_values, (report_spec.model_dump_json() if report_spec is not None else None), version_manifest["country_package_version"], version_manifest["policyengine_version"], @@ -441,22 +595,31 @@ def _ensure_report_output_dual_write_state_in_transaction( ) else: mutable_run = self._select_mutable_run(report_output, runs_descending) - if mutable_run is not None and not self._run_matches_parent( - mutable_run, - report_output, - report_spec, - version_manifest, - ): - self._update_report_run_in_transaction( - tx, - run_id=mutable_run["id"], - report_output=report_output, - report_spec=report_spec, - version_manifest=version_manifest, + if mutable_run is not None: + run_matches_parent = self._run_matches_parent( + mutable_run, + report_output, + report_spec, + version_manifest, ) - runs_descending = self._list_report_runs_descending( - report_output_id, queryer=tx + needs_timestamp_sync = self._run_needs_timestamp_sync( + mutable_run, report_output["status"] ) + if not run_matches_parent or needs_timestamp_sync: + run_matches_result = run_matches_report_result( + mutable_run, report_output + ) + self._update_report_run_in_transaction( + tx, + run_id=mutable_run["id"], + report_output=report_output, + report_spec=report_spec, + version_manifest=version_manifest, + preserve_terminal_finished_at=run_matches_result, + ) + runs_descending = self._list_report_runs_descending( + report_output_id, queryer=tx + ) self._sync_parent_pointers_in_transaction(tx, report_output, runs_descending) refreshed_report_output = self._get_report_output_row( @@ -466,7 +629,7 @@ def _ensure_report_output_dual_write_state_in_transaction( ) if refreshed_report_output is None: raise ValueError(f"Report output #{report_output_id} not found after sync") - return refreshed_report_output + return self._with_display_run_timestamps(refreshed_report_output, queryer=tx) def ensure_report_output_dual_write_state( self, @@ -485,11 +648,31 @@ def get_stored_report_output( self, country_id: str, report_output_id: int ) -> dict | None: """ - Get the raw stored report output row by ID without aliasing to the - current runtime lineage. This is useful for mutation paths, which must - update the originally addressed row rather than a resolved alias. + Get a stored report output row without aliasing to current runtime lineage. + + This is used by mutation paths that must address the originally + requested row. It still runs dual-write synchronization, so it may + bootstrap or repair run/spec metadata and returns the display-run + timestamp projection. It is therefore not a raw database read. + + TODO: Split raw storage lookup from synchronized response projection in + a later run-backed read migration PR. """ - return self._get_report_output_row(report_output_id, country_id=country_id) + report_output = self._get_report_output_row( + report_output_id, country_id=country_id + ) + if report_output is None: + return None + return self.ensure_report_output_dual_write_state( + report_output_id, + country_id=country_id, + ) + + def report_output_exists(self, country_id: str, report_output_id: int) -> bool: + return ( + self._get_report_output_row(report_output_id, country_id=country_id) + is not None + ) def _is_current_report_output(self, report_output: dict) -> bool: return report_output.get("api_version") == get_report_output_cache_version( @@ -530,7 +713,7 @@ def _get_or_create_current_report_output(self, report_output: dict) -> dict: year=report_output["year"], ) if current_report is not None: - return current_report + return self._with_display_run_timestamps(current_report) return self.create_report_output( country_id=report_output["country_id"], @@ -565,7 +748,11 @@ def find_existing_report_output( ) if existing_report is not None: print(f"Found existing report output with ID: {existing_report['id']}") - return existing_report + return self.ensure_report_output_dual_write_state( + existing_report["id"], + country_id=country_id, + ) + return None except Exception as e: print(f"Error checking for existing report output. Details: {str(e)}") @@ -691,7 +878,10 @@ def get_report_output(self, country_id: str, report_output_id: int) -> dict | No return None if self._is_current_report_output(report_output): - return report_output + return self.ensure_report_output_dual_write_state( + report_output_id, + country_id=country_id, + ) current_report = self._get_or_create_current_report_output(report_output) return self._alias_report_output(report_output_id, current_report) @@ -745,6 +935,14 @@ def tx_callback(tx): if requested_report is None: raise ValueError(f"Report output #{report_id} not found") + if status == "running" and not self._has_mutable_running_run( + requested_report, queryer=tx + ): + raise ValueError( + "Cannot mark report output running without an active " + "pending or running report run" + ) + tx.query( f"UPDATE report_outputs SET {', '.join(update_fields)} WHERE id = ? AND country_id = ?", (*update_values, report_id, country_id), diff --git a/policyengine_api/services/report_run_service.py b/policyengine_api/services/report_run_service.py index 2646526f4..9899f6cc9 100644 --- a/policyengine_api/services/report_run_service.py +++ b/policyengine_api/services/report_run_service.py @@ -1,10 +1,12 @@ import json import uuid +from datetime import datetime, timezone from typing import Any from sqlalchemy.engine.row import Row from policyengine_api.data import database +from policyengine_api.services.run_sync_utils import select_display_report_run REPORT_RUN_VERSION_FIELDS = ( @@ -21,6 +23,9 @@ class ReportRunService: + def _utc_timestamp(self) -> str: + return datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") + def _serialize_json( self, value: dict[str, Any] | list[Any] | str | None ) -> str | None: @@ -77,6 +82,12 @@ def create_run_transaction(tx) -> None: else 1 ) + requested_at = self._utc_timestamp() + is_terminal = status in ("complete", "error") + has_started = status in ("running", "complete", "error") + started_at = requested_at if has_started else None + finished_at = requested_at if is_terminal else None + tx.query( f""" INSERT INTO report_output_runs ( @@ -93,9 +104,9 @@ def create_run_transaction(tx) -> None: self._serialize_json(output), error_message, trigger_type, - None, - None, - None, + requested_at, + started_at, + finished_at, source_run_id, self._serialize_json(report_spec_snapshot), *[ @@ -139,14 +150,7 @@ def get_newest_report_output_run(self, report_output_id: int) -> dict | None: return self._parse_run_row(row) def select_display_run(self, report_output: dict) -> dict | None: - if report_output.get("active_run_id"): - active_run = self.get_report_output_run(report_output["active_run_id"]) - if active_run is not None: - return active_run - if report_output.get("latest_successful_run_id"): - latest_successful_run = self.get_report_output_run( - report_output["latest_successful_run_id"] - ) - if latest_successful_run is not None: - return latest_successful_run - return self.get_newest_report_output_run(report_output["id"]) + runs_descending = list( + reversed(self.list_report_output_runs(report_output["id"])) + ) + return select_display_report_run(report_output, runs_descending) diff --git a/policyengine_api/services/run_sync_utils.py b/policyengine_api/services/run_sync_utils.py index 7de888cde..56220fe36 100644 --- a/policyengine_api/services/run_sync_utils.py +++ b/policyengine_api/services/run_sync_utils.py @@ -21,6 +21,41 @@ def get_latest_successful_run_id(runs: list[dict]) -> str | None: return None +def run_matches_report_result(run: dict, report_output: dict) -> bool: + return ( + run["status"] == report_output["status"] + and run.get("output") == report_output.get("output") + and run.get("error_message") == report_output.get("error_message") + ) + + +def select_display_report_run( + report_output: dict, runs_descending: list[dict] +) -> dict | None: + active_run_id = report_output.get("active_run_id") + if active_run_id is not None: + for run in runs_descending: + if run["id"] == active_run_id: + return run + + if report_output["status"] == "error": + for run in runs_descending: + if run_matches_report_result(run, report_output): + return run + + latest_successful_run_id = report_output.get("latest_successful_run_id") + if latest_successful_run_id is not None: + for run in runs_descending: + if run["id"] == latest_successful_run_id: + return run + + for run in runs_descending: + if run_matches_report_result(run, report_output): + return run + + return runs_descending[0] if runs_descending else None + + def determine_parent_pointers( status: str, runs_descending: list[dict] ) -> tuple[str | None, str | None]: diff --git a/tests/to_refactor/python/test_yearly_var_removal.py b/tests/to_refactor/python/test_yearly_var_removal.py deleted file mode 100644 index 875176fe4..000000000 --- a/tests/to_refactor/python/test_yearly_var_removal.py +++ /dev/null @@ -1,330 +0,0 @@ -import pytest -import json -import uuid - -from policyengine_api.endpoints.household import get_household_under_policy -from policyengine_api.services.metadata_service import MetadataService -from policyengine_api.services.policy_service import PolicyService -from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS -from policyengine_api.data import database -from policyengine_api.api import app - -policy_service = PolicyService() -metadata_service = MetadataService() - - -@pytest.fixture -def client(): - app.config["TESTING"] = True - with app.test_client() as client: - yield client - - -def make_test_household_id() -> str: - # Use a negative signed 32-bit-ish integer string to avoid colliding with - # normal autoincrement rows while remaining compatible with INT columns. - return str(-((uuid.uuid4().int % 2_000_000_000) or 1)) - - -def create_test_household(household_id, country_id): - test_household = None - - row = database.query( - f"SELECT * FROM household WHERE id = ? AND country_id = ?", - (household_id, country_id), - ).fetchone() - - if row is not None: - # WARNING: This could mutate existing arrays if running make-test - # instead of make debug-test specifically on production server - remove_test_household(household_id, country_id) - - with open( - f"./tests/data/{country_id}_household.json", - "r", - encoding="utf-8", - ) as f: - test_household = json.load(f) - - try: - row = database.query( - f"INSERT INTO household (id, country_id, household_json, household_hash, label, api_version) VALUES (?, ?, ?, ?, ?, ?)", - ( - household_id, - country_id, - json.dumps(test_household), - "Garbage value", - "Garbage value", - "0.0.0", - ), - ) - - except Exception as err: - raise err - - return household_id - - -def remove_test_household(household_id, country_id): - row = database.query( - f"SELECT * FROM household WHERE id = ? AND country_id = ?", - (household_id, country_id), - ).fetchone() - - if row is not None: - try: - database.query( - f"DELETE FROM household WHERE id = ? AND country_id = ?", - (household_id, country_id), - ) - except Exception as err: - raise err - - return True - - -def remove_calculated_hup(household_id, policy_id, country_id): - """ - Function to remove the calculated household under policy generated - by get_household_under_policy, for testing purposes - """ - - api_version = COUNTRY_PACKAGE_VERSIONS.get(country_id) - - try: - database.query( - f"DELETE FROM computed_household WHERE household_id = ? AND policy_id = ? AND api_version = ?", - (household_id, policy_id, api_version), - ) - except Exception as err: - raise err - - -def interface_test_household_under_policy( - country_id: str, current_law: str, excluded_vars: list -): - """ - Test that a household under current law contains all relevant - """ - # Note: Attempted to mock the database.query statements in get_household_under_policy, - # but was unable to, hence the (less secure) emission of SQL creation, followed by deletion - CURRENT_LAW = current_law - - # Value to invalidated if any key is not present in household - is_test_passing = True - - test_household_id = make_test_household_id() - - # Fetch live country metadata - metadata = metadata_service.get_metadata(country_id) - - try: - # Create the test household on the local db instance - create_test_household(test_household_id, country_id) - - # Create a result object by simply calling the relevant function - result_object = get_household_under_policy( - country_id, test_household_id, CURRENT_LAW - )["result"] - finally: - remove_test_household(test_household_id, country_id) - remove_calculated_hup(test_household_id, CURRENT_LAW, country_id) - - # Create a dict of entity singular and plural terms for testing - entities_map = {} - for entity in metadata["entities"]: - entity_plural = metadata["entities"][entity]["plural"] - entities_map[entity_plural] = entity - - # Create a set of all variables listed within the metadata that are yearly, - # as well as one that will store all variables accessed while looping - # Note: This removes issues with SNAP variables, which are calculated monthly - var_filter = lambda x: ( - (metadata["variables"][x]["definitionPeriod"] == "year") - and x not in excluded_vars - ) - metadata_var_set = set(filter(var_filter, metadata["variables"].keys())) - result_var_set = set() - - # Loop through every third-level variable in result_object - for entity_group in result_object: - for entity in result_object[entity_group]: - entity_group_singularized = entities_map[entity_group] - for variable in result_object[entity_group][entity]: - # Skip ignored variables - if ( - variable in excluded_vars - or metadata["variables"][variable]["definitionPeriod"] != "year" - ): - continue - - # Ensure that the variable exists in both - # result_object and test_object - if variable not in metadata["variables"]: - print(f"Failing due to variable {variable} not in metadata") - is_test_passing = False - break - - # Ensure that variable exists within the correct - # entity - if ( - variable not in excluded_vars - and entity_group_singularized - != metadata["variables"][variable]["entity"] - ): - print( - f"Failing due to variable {variable} not in entity group {entity_group_singularized}" - ) - is_test_passing = False - break - - # Add variable to result var set - result_var_set.add(variable) - - if result_var_set != metadata_var_set: - results_diff = result_var_set.difference(metadata_var_set) - metadata_diff = metadata_var_set.difference(result_var_set) - if len(results_diff) > 0: - print("Error: The following values are only present in the result object:") - print(results_diff) - if len(metadata_diff) > 0: - print("Error: The following values are only present in the metadata:") - print(metadata_diff) - is_test_passing = False - - return is_test_passing - - -def test_make_test_household_id_returns_negative_integer_string(): - test_household_id = make_test_household_id() - - assert test_household_id.startswith("-") - assert int(test_household_id) < 0 - - -def test_make_test_household_id_is_unique(): - generated_ids = {make_test_household_id() for _ in range(100)} - - assert len(generated_ids) == 100 - - -def test_us_household_under_policy(): - """ - Test that a US household under current law is created correctly - """ - - is_test_passing = interface_test_household_under_policy("us", "2", ["members"]) - - assert is_test_passing == True - - -def test_uk_household_under_policy(): - """ - Test that a UK household under current law is created correctly - """ - - # The extra excluded variables all contain OpenFisca State entities, - # necessitating their removal - is_test_passing = interface_test_household_under_policy( - "uk", - "1", - ["members", "property_sale_rate", "state_id", "state_weight"], - ) - - assert is_test_passing == True - - -def test_get_calculate(client): - """ - Test the get_calculate endpoint with the same data as - test_us_household_under_policy. Note that redis must be running - for this test to function properly. - """ - - CURRENT_LAW_US = "2" - COUNTRY_ID = "us" - - test_household = None - test_object = {} - is_test_passing = True - - excluded_vars = ["members"] - - # Fetch live country metadata - metadata = metadata_service.get_metadata(COUNTRY_ID) - - with open(f"./tests/data/us_household.json", "r", encoding="utf-8") as f: - test_household = json.load(f) - - # Current law is represented by empty dict/empty JSON - test_policy = {} - - test_object["policy"] = test_policy - test_object["household"] = test_household - - res = client.post("/us/calculate-full", json=test_object) - result_object = json.loads(res.text)["result"] - - # Create a dict of entity singular and plural terms for testing - entities_map = {} - for entity in metadata["entities"]: - entity_plural = metadata["entities"][entity]["plural"] - entities_map[entity_plural] = entity - - # Create a set of all variables listed within the metadata that are yearly, - # as well as one that will store all variables accessed while looping - # Note: This removes issues with SNAP variables, which are calculated monthly - var_filter = lambda x: ( - (metadata["variables"][x]["definitionPeriod"] == "year") - and x not in excluded_vars - ) - metadata_var_set = set(filter(var_filter, metadata["variables"].keys())) - result_var_set = set() - - # Loop through every third-level variable in result_object - for entity_group in result_object: - for entity in result_object[entity_group]: - entity_group_singularized = entities_map[entity_group] - for variable in result_object[entity_group][entity]: - # Skip ignored variables - if ( - variable in excluded_vars - or metadata["variables"][variable]["definitionPeriod"] != "year" - ): - continue - - # Ensure that the variable exists in both - # result_object and test_object - if variable not in metadata["variables"]: - print(f"Failing due to variable {variable} not in metadata") - is_test_passing = False - break - - # Ensure that variable exists within the correct - # entity - if ( - variable not in excluded_vars - and entity_group_singularized - != metadata["variables"][variable]["entity"] - ): - print( - f"Failing due to variable {variable} not in entity group {entity_group_singularized}" - ) - is_test_passing = False - break - - # Add variable to result var set - result_var_set.add(variable) - - if result_var_set != metadata_var_set: - results_diff = result_var_set.difference(metadata_var_set) - metadata_diff = metadata_var_set.difference(result_var_set) - if len(results_diff) > 0: - print("Error: The following values are only present in the result object:") - print(results_diff) - if len(metadata_diff) > 0: - print("Error: The following values are only present in the metadata:") - print(metadata_diff) - is_test_passing = False - - assert is_test_passing == True diff --git a/tests/unit/services/test_report_output_service.py b/tests/unit/services/test_report_output_service.py index c1b6709a5..55ee2ff62 100644 --- a/tests/unit/services/test_report_output_service.py +++ b/tests/unit/services/test_report_output_service.py @@ -1,8 +1,11 @@ import pytest import json +from datetime import datetime, timezone from policyengine_api.constants import get_report_output_cache_version from policyengine_api.services.report_output_service import ReportOutputService +from policyengine_api.services.report_run_service import ReportRunService +from policyengine_api.services.run_sync_utils import select_display_report_run from policyengine_api.services.simulation_service import SimulationService from tests.fixtures.services import report_output_fixtures @@ -10,9 +13,69 @@ pytest_plugins = ("tests.fixtures.services.report_output_fixtures",) service = ReportOutputService() +report_run_service = ReportRunService() simulation_service = SimulationService() +class TestReportOutputRunTimestamps: + def test_format_run_timestamp_handles_supported_values(self): + assert ( + service._format_run_timestamp(datetime(2026, 5, 4, 12, 0, 0)) + == "2026-05-04T12:00:00Z" + ) + assert ( + service._format_run_timestamp( + datetime(2026, 5, 4, 12, 0, 0, tzinfo=timezone.utc) + ) + == "2026-05-04T12:00:00Z" + ) + assert service._format_run_timestamp("") is None + assert ( + service._format_run_timestamp("2026-05-04T12:00:00") + == "2026-05-04T12:00:00Z" + ) + assert ( + service._format_run_timestamp("2026-05-04T12:00:00Z") + == "2026-05-04T12:00:00Z" + ) + assert ( + service._format_run_timestamp("2026-05-04T12:00:00+01:00") + == "2026-05-04T11:00:00Z" + ) + assert ( + service._format_run_timestamp("2026-05-04 12:00:00.123456") + == "2026-05-04T12:00:00Z" + ) + + def test_select_display_run_uses_matching_result_before_newest_fallback(self): + report_output = { + "id": 1, + "status": "complete", + "output": '{"ok": true}', + "error_message": None, + "active_run_id": None, + "latest_successful_run_id": None, + } + matching_run = { + "id": "matching", + "status": "complete", + "output": '{"ok": true}', + "error_message": None, + } + newest_non_matching_run = { + "id": "newest", + "status": "pending", + "output": None, + "error_message": None, + } + + selected_run = select_display_report_run( + report_output, [newest_non_matching_run, matching_run] + ) + + assert selected_run["id"] == "matching" + + class TestFindExistingReportOutput: """Test finding existing report outputs in the database.""" @@ -333,6 +396,10 @@ def test_create_report_output_populates_dual_write_state(self, test_db): assert run is not None assert run["status"] == "pending" assert run["trigger_type"] == "initial" + assert run["requested_at"] is not None + assert created_report["requested_at"] is not None + assert created_report["started_at"] is None + assert created_report["finished_at"] is None snapshot = run["report_spec_snapshot_json"] if isinstance(snapshot, str): snapshot = json.loads(snapshot) @@ -500,6 +567,626 @@ def test_get_report_output_with_json_output(self, test_db): assert result["year"] == "2025" # Frontend will parse this string + def test_get_report_output_includes_display_run_timestamps(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_report_display_timestamps", + population_type="household", + policy_id=40, + ) + report = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + service.update_report_output( + country_id="us", + report_id=report["id"], + status="complete", + output=json.dumps({"ok": True}), + ) + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report["id"],), + ).fetchone() + test_db.query( + """ + UPDATE report_output_runs + SET requested_at = ?, started_at = ?, finished_at = ? + WHERE id = ? + """, + ( + "2026-05-04 12:00:00", + "2026-05-04 12:01:00", + "2026-05-04 12:02:00", + run["id"], + ), + ) + + result = service.get_report_output( + country_id="us", report_output_id=report["id"] + ) + + assert result["requested_at"] == "2026-05-04T12:00:00Z" + assert result["started_at"] == "2026-05-04T12:01:00Z" + assert result["finished_at"] == "2026-05-04T12:02:00Z" + + def test_update_report_output_sets_finished_at_on_display_run(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_report_finished_timestamp", + population_type="household", + policy_id=41, + ) + report = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + + success = service.update_report_output( + country_id="us", + report_id=report["id"], + status="complete", + output=json.dumps({"ok": True}), + ) + + assert success is True + result = service.get_report_output( + country_id="us", report_output_id=report["id"] + ) + assert result["status"] == "complete" + assert result["requested_at"] is not None + assert result["started_at"] is not None + assert result["finished_at"] is not None + + def test_error_rerun_uses_error_run_timestamp_over_previous_success(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_report_error_rerun_timestamp", + population_type="household", + policy_id=42, + ) + report = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + service.update_report_output( + country_id="us", + report_id=report["id"], + status="complete", + output=json.dumps({"ok": True}), + ) + completed_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report["id"],), + ).fetchone() + previous_success_id = completed_report["latest_successful_run_id"] + test_db.query( + """ + UPDATE report_output_runs + SET requested_at = ?, started_at = ?, finished_at = ? + WHERE id = ? + """, + ( + "2026-05-04 10:00:00", + "2026-05-04 10:01:00", + "2026-05-04 10:02:00", + previous_success_id, + ), + ) + rerun = report_run_service.create_report_output_run( + report["id"], trigger_type="rerun" + ) + test_db.query( + """ + UPDATE report_outputs + SET active_run_id = ?, latest_successful_run_id = ? + WHERE id = ? + """, + (rerun["id"], previous_success_id, report["id"]), + ) + + service.update_report_output( + country_id="us", + report_id=report["id"], + status="error", + error_message="rerun failed", + ) + + updated_rerun = test_db.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (rerun["id"],), + ).fetchone() + result = service.get_report_output( + country_id="us", report_output_id=report["id"] + ) + + assert result["status"] == "error" + assert result["finished_at"] == service._format_run_timestamp( + updated_rerun["finished_at"] + ) + assert result["finished_at"] != "2026-05-04T10:02:00Z" + + def test_pending_update_clears_terminal_display_timestamps(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_report_pending_timestamp_reset", + population_type="household", + policy_id=43, + ) + report = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + service.update_report_output( + country_id="us", + report_id=report["id"], + status="complete", + output=json.dumps({"ok": True}), + ) + + service.update_report_output( + country_id="us", + report_id=report["id"], + status="pending", + ) + + result = service.get_report_output( + country_id="us", report_output_id=report["id"] + ) + + assert result["status"] == "pending" + assert result["requested_at"] is not None + assert result["started_at"] is None + assert result["finished_at"] is None + + def test_running_update_sets_started_at_without_finished_at(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_report_running_timestamp", + population_type="household", + policy_id=44, + ) + report = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + + service.update_report_output( + country_id="us", + report_id=report["id"], + status="running", + ) + + result = service.get_report_output( + country_id="us", report_output_id=report["id"] + ) + + assert result["status"] == "running" + assert result["requested_at"] is not None + assert result["started_at"] is not None + assert result["finished_at"] is None + + def test_running_update_requires_non_terminal_run_after_success(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_report_running_without_active_run", + population_type="household", + policy_id=45, + ) + report = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + service.update_report_output( + country_id="us", + report_id=report["id"], + status="complete", + output=json.dumps({"ok": True}), + ) + completed_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report["id"],), + ).fetchone() + successful_run_id = completed_report["latest_successful_run_id"] + + with pytest.raises(ValueError, match="active pending or running"): + service.update_report_output( + country_id="us", + report_id=report["id"], + status="running", + ) + + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report["id"],), + ).fetchone() + successful_run = test_db.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (successful_run_id,), + ).fetchone() + assert stored_report["status"] == "complete" + assert stored_report["active_run_id"] is None + assert stored_report["latest_successful_run_id"] == successful_run_id + assert successful_run["status"] == "complete" + assert successful_run["output"] == json.dumps({"ok": True}) + assert successful_run["finished_at"] is not None + + def test_running_update_uses_active_rerun_without_rewriting_success(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_report_active_running_rerun", + population_type="household", + policy_id=46, + ) + report = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + service.update_report_output( + country_id="us", + report_id=report["id"], + status="complete", + output=json.dumps({"ok": True}), + ) + completed_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report["id"],), + ).fetchone() + successful_run_id = completed_report["latest_successful_run_id"] + rerun = report_run_service.create_report_output_run( + report["id"], trigger_type="rerun" + ) + test_db.query( + """ + UPDATE report_outputs + SET active_run_id = ?, latest_successful_run_id = ? + WHERE id = ? + """, + (rerun["id"], successful_run_id, report["id"]), + ) + + service.update_report_output( + country_id="us", + report_id=report["id"], + status="running", + ) + + successful_run = test_db.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (successful_run_id,), + ).fetchone() + active_run = test_db.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (rerun["id"],), + ).fetchone() + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report["id"],), + ).fetchone() + assert successful_run["status"] == "complete" + assert successful_run["finished_at"] is not None + assert active_run["status"] == "running" + assert active_run["started_at"] is not None + assert active_run["finished_at"] is None + assert stored_report["active_run_id"] == rerun["id"] + assert stored_report["latest_successful_run_id"] == successful_run_id + + def test_get_report_output_does_not_rewrite_terminal_active_run_for_running_parent( + self, test_db + ): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_report_running_bad_active_run", + population_type="household", + policy_id=47, + ) + report = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + output_json = json.dumps({"ok": True}) + service.update_report_output( + country_id="us", + report_id=report["id"], + status="complete", + output=output_json, + ) + completed_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report["id"],), + ).fetchone() + successful_run_id = completed_report["latest_successful_run_id"] + test_db.query( + """ + UPDATE report_outputs + SET status = ?, active_run_id = ?, latest_successful_run_id = ? + WHERE id = ? + """, + ("running", successful_run_id, successful_run_id, report["id"]), + ) + + result = service.get_report_output( + country_id="us", report_output_id=report["id"] + ) + + successful_run = test_db.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (successful_run_id,), + ).fetchone() + assert result["status"] == "running" + assert successful_run["status"] == "complete" + assert successful_run["output"] == output_json + assert successful_run["finished_at"] is not None + + def test_get_stored_report_output_includes_display_run_timestamps(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_report_stored_timestamp", + population_type="household", + policy_id=48, + ) + report = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + + result = service.get_stored_report_output("us", report["id"]) + + assert result is not None + assert result["requested_at"] is not None + assert result["started_at"] is None + assert result["finished_at"] is None + + def test_get_stored_report_output_returns_none_when_missing(self, test_db): + assert service.get_stored_report_output("us", 999999) is None + + def test_get_report_output_backfills_missing_timestamps_on_matching_legacy_run( + self, test_db + ): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_report_legacy_timestamp_get", + population_type="household", + policy_id=46, + ) + report = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + service.update_report_output( + country_id="us", + report_id=report["id"], + status="complete", + output=json.dumps({"ok": True}), + ) + test_db.query( + """ + UPDATE report_output_runs + SET requested_at = NULL, started_at = NULL, finished_at = NULL + WHERE report_output_id = ? + """, + (report["id"],), + ) + + result = service.get_report_output( + country_id="us", report_output_id=report["id"] + ) + + assert result["requested_at"] is not None + assert result["started_at"] is not None + assert result["finished_at"] is not None + stored_run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report["id"],), + ).fetchone() + assert stored_run["requested_at"] is not None + assert stored_run["started_at"] is not None + assert stored_run["finished_at"] is not None + + def test_get_report_output_preserves_existing_finished_at_during_backfill( + self, test_db + ): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_report_legacy_finished_at", + population_type="household", + policy_id=47, + ) + report = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + service.update_report_output( + country_id="us", + report_id=report["id"], + status="complete", + output=json.dumps({"ok": True}), + ) + test_db.query( + """ + UPDATE report_output_runs + SET requested_at = NULL, started_at = ?, finished_at = ? + WHERE report_output_id = ? + """, + ( + "2026-05-04 12:01:00", + "2026-05-04 12:02:00", + report["id"], + ), + ) + + result = service.get_report_output( + country_id="us", report_output_id=report["id"] + ) + + assert result["requested_at"] == "2026-05-04T12:01:00Z" + assert result["started_at"] == "2026-05-04T12:01:00Z" + assert result["finished_at"] == "2026-05-04T12:02:00Z" + stored_run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report["id"],), + ).fetchone() + assert service._format_run_timestamp(stored_run["requested_at"]) == ( + "2026-05-04T12:01:00Z" + ) + assert service._format_run_timestamp(stored_run["finished_at"]) == ( + "2026-05-04T12:02:00Z" + ) + + def test_get_report_output_preserves_finished_at_while_backfilling_metadata( + self, test_db + ): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_report_legacy_finished_metadata", + population_type="household", + policy_id=48, + ) + report = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + service.update_report_output( + country_id="us", + report_id=report["id"], + status="complete", + output=json.dumps({"ok": True}), + ) + test_db.query( + """ + UPDATE report_output_runs + SET requested_at = NULL, + started_at = ?, + finished_at = ?, + report_spec_snapshot_json = NULL, + country_package_version = NULL + WHERE report_output_id = ? + """, + ( + "2026-05-04 12:01:00", + "2026-05-04 12:02:00", + report["id"], + ), + ) + + result = service.get_report_output( + country_id="us", report_output_id=report["id"] + ) + + assert result["requested_at"] == "2026-05-04T12:01:00Z" + assert result["started_at"] == "2026-05-04T12:01:00Z" + assert result["finished_at"] == "2026-05-04T12:02:00Z" + stored_run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report["id"],), + ).fetchone() + assert stored_run["report_spec_snapshot_json"] is not None + assert stored_run["country_package_version"] is not None + assert service._format_run_timestamp(stored_run["requested_at"]) == ( + "2026-05-04T12:01:00Z" + ) + assert service._format_run_timestamp(stored_run["finished_at"]) == ( + "2026-05-04T12:02:00Z" + ) + + def test_get_report_output_bootstraps_running_legacy_run_started_at(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_report_legacy_running", + population_type="household", + policy_id=49, + ) + test_db.query( + """ + INSERT INTO report_outputs ( + country_id, simulation_1_id, simulation_2_id, status, api_version, year + ) VALUES (?, ?, ?, ?, ?, ?) + """, + ( + "us", + simulation["id"], + None, + "running", + get_report_output_cache_version("us"), + "2025", + ), + ) + report = test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + ).fetchone() + + result = service.get_report_output( + country_id="us", report_output_id=report["id"] + ) + + assert result["status"] == "running" + assert result["requested_at"] is not None + assert result["started_at"] is not None + assert result["finished_at"] is None + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report["id"],), + ).fetchone() + assert run["status"] == "running" + assert run["started_at"] is not None + assert run["finished_at"] is None + + def test_find_existing_report_output_backfills_missing_timestamps(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_report_legacy_timestamp_find", + population_type="household", + policy_id=50, + ) + report = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + test_db.query( + """ + UPDATE report_output_runs + SET requested_at = NULL + WHERE report_output_id = ? + """, + (report["id"],), + ) + + result = service.find_existing_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + + assert result is not None + assert result["requested_at"] is not None + def test_get_report_output_resolves_stale_id_to_current_runtime_row(self, test_db): stale_output = { "budget": {"budgetary_impact": 1}, diff --git a/tests/unit/services/test_report_run_service.py b/tests/unit/services/test_report_run_service.py index b286dbee3..68aedb48f 100644 --- a/tests/unit/services/test_report_run_service.py +++ b/tests/unit/services/test_report_run_service.py @@ -40,6 +40,9 @@ def test_creates_report_runs_with_incrementing_sequence(self, test_db): assert first_run["run_sequence"] == 2 assert first_run["trigger_type"] == "initial" + assert first_run["requested_at"] is not None + assert first_run["started_at"] is None + assert first_run["finished_at"] is None assert first_run["report_spec_snapshot_json"] == {"country_id": "us"} assert first_run["country_package_version"] == "us-1.0.0" assert first_run["report_cache_version"] == "r123" @@ -100,6 +103,30 @@ def test_raises_when_parent_report_output_is_missing(self, test_db): assert "Report output #999999 not found" in str(exc_info.value) + def test_running_report_run_sets_started_at(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_running_run_timestamp", + population_type="household", + policy_id=8, + ) + report_output = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + + run = report_run_service.create_report_output_run( + report_output["id"], + status="running", + trigger_type="rerun", + ) + + assert run["requested_at"] is not None + assert run["started_at"] is not None + assert run["finished_at"] is None + class TestSelectDisplayReportRun: def test_prefers_active_run(self, test_db): @@ -174,6 +201,48 @@ def test_falls_back_to_latest_successful_run(self, test_db): assert selected_run["id"] == successful_run["id"] + def test_prefers_matching_error_run_over_previous_success(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_5b", + population_type="household", + policy_id=5, + ) + report_output = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + successful_run = report_run_service.create_report_output_run( + report_output["id"], + status="complete", + trigger_type="initial", + output={"ok": True}, + ) + error_run = report_run_service.create_report_output_run( + report_output["id"], + status="error", + trigger_type="rerun", + error_message="rerun failed", + ) + test_db.query( + """ + UPDATE report_outputs + SET status = ?, error_message = ?, active_run_id = NULL, latest_successful_run_id = ? + WHERE id = ? + """, + ("error", "rerun failed", successful_run["id"], report_output["id"]), + ) + updated_report_output = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report_output["id"],), + ).fetchone() + + selected_run = report_run_service.select_display_run(updated_report_output) + + assert selected_run["id"] == error_run["id"] + def test_falls_back_when_active_run_pointer_is_stale(self, test_db): simulation = simulation_service.create_simulation( country_id="us", @@ -278,3 +347,36 @@ def test_falls_back_to_newest_run_when_latest_successful_pointer_is_stale( selected_run = report_run_service.select_display_run(updated_report_output) assert selected_run["id"] == newest_run["id"] + + def test_falls_back_to_newest_run_when_no_pointer_or_result_match(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_6b", + population_type="household", + policy_id=6, + ) + report_output = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + newest_run = report_run_service.create_report_output_run( + report_output["id"], status="pending", trigger_type="rerun" + ) + test_db.query( + """ + UPDATE report_outputs + SET status = ?, active_run_id = NULL, latest_successful_run_id = NULL + WHERE id = ? + """, + ("complete", report_output["id"]), + ) + updated_report_output = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report_output["id"],), + ).fetchone() + + selected_run = report_run_service.select_display_run(updated_report_output) + + assert selected_run["id"] == newest_run["id"] diff --git a/tests/unit/test_openapi_report_routes.py b/tests/unit/test_openapi_report_routes.py new file mode 100644 index 000000000..517c55d07 --- /dev/null +++ b/tests/unit/test_openapi_report_routes.py @@ -0,0 +1,34 @@ +from pathlib import Path + +import yaml + + +def _load_spec() -> dict: + spec_path = Path(__file__).parents[2] / "policyengine_api" / "openapi_spec.yaml" + return yaml.safe_load(spec_path.read_text(encoding="utf-8")) + + +def _result_properties(response_schema: dict) -> dict: + return response_schema["content"]["application/json"]["schema"]["properties"][ + "result" + ]["properties"] + + +def test_report_output_openapi_responses_include_run_timestamp_fields(): + spec = _load_spec() + paths = spec["paths"] + timestamp_fields = {"requested_at", "started_at", "finished_at"} + + response_schemas = [ + paths["/{country_id}/report"]["post"]["responses"]["200"], + paths["/{country_id}/report"]["post"]["responses"]["201"], + paths["/{country_id}/report"]["patch"]["responses"]["200"], + paths["/{country_id}/report/{report_id}"]["get"]["responses"]["200"], + ] + + for response_schema in response_schemas: + result_properties = _result_properties(response_schema) + assert timestamp_fields.issubset(result_properties) + for field in timestamp_fields: + assert result_properties[field]["type"] == "string" + assert result_properties[field]["nullable"] is True diff --git a/tests/unit/test_stage5_routes.py b/tests/unit/test_stage5_routes.py index 5d0ca4f79..ec9f34e1b 100644 --- a/tests/unit/test_stage5_routes.py +++ b/tests/unit/test_stage5_routes.py @@ -6,11 +6,13 @@ from policyengine_api.routes.report_output_routes import report_output_bp from policyengine_api.routes.simulation_routes import simulation_bp from policyengine_api.services.report_output_service import ReportOutputService +from policyengine_api.services.report_run_service import ReportRunService from policyengine_api.services.simulation_service import SimulationService simulation_service = SimulationService() report_output_service = ReportOutputService() +report_run_service = ReportRunService() def create_test_client() -> Flask: @@ -120,6 +122,51 @@ def test_create_report_output_existing_row_repairs_dual_write_state(test_db): assert snapshot["report_kind"] == "household_single" +def test_post_report_output_returns_timestamp_fields_for_new_and_existing_report( + test_db, +): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_report_timestamps", + population_type="household", + policy_id=46, + ) + + client = create_test_client() + response = client.post( + "/us/report", + json={ + "simulation_1_id": simulation["id"], + "simulation_2_id": None, + "year": "2025", + }, + ) + + assert response.status_code == 201 + payload = response.get_json() + created_report = payload["result"] + assert created_report["requested_at"] is not None + assert created_report["started_at"] is None + assert created_report["finished_at"] is None + + existing_response = client.post( + "/us/report", + json={ + "simulation_1_id": simulation["id"], + "simulation_2_id": None, + "year": "2025", + }, + ) + + assert existing_response.status_code == 200 + existing_payload = existing_response.get_json() + existing_report = existing_payload["result"] + assert existing_report["id"] == created_report["id"] + assert existing_report["requested_at"] is not None + assert existing_report["started_at"] is None + assert existing_report["finished_at"] is None + + def test_create_report_output_missing_primary_simulation_returns_bad_request(test_db): client = create_test_client() response = client.post( @@ -264,3 +311,272 @@ def test_patch_report_output_wrong_country_returns_not_found_and_does_not_mutate assert stored_report["country_id"] == "us" assert stored_report["status"] == "pending" assert stored_report["output"] is None + + +def test_patch_report_output_accepts_running_status(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_running_report", + population_type="household", + policy_id=45, + ) + report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + + client = create_test_client() + response = client.patch( + "/us/report", + json={ + "id": report["id"], + "status": "running", + }, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["status"] == "running" + assert payload["result"]["requested_at"] is not None + assert payload["result"]["started_at"] is not None + assert payload["result"]["finished_at"] is None + + +def test_get_report_output_serializes_display_run_timestamps(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_get_timestamp", + population_type="household", + policy_id=47, + ) + report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + report_output_service.update_report_output( + country_id="us", + report_id=report["id"], + status="complete", + output=json.dumps({"ok": True}), + ) + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report["id"],), + ).fetchone() + test_db.query( + """ + UPDATE report_output_runs + SET requested_at = ?, started_at = ?, finished_at = ? + WHERE id = ? + """, + ( + "2026-05-04 12:00:00", + "2026-05-04 12:01:00", + "2026-05-04 12:02:00", + run["id"], + ), + ) + + client = create_test_client() + response = client.get(f"/us/report/{report['id']}") + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["requested_at"] == "2026-05-04T12:00:00Z" + assert payload["result"]["started_at"] == "2026-05-04T12:01:00Z" + assert payload["result"]["finished_at"] == "2026-05-04T12:02:00Z" + + +def test_patch_report_output_running_uses_active_rerun_route_path(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_active_running_rerun", + population_type="household", + policy_id=48, + ) + report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + report_output_service.update_report_output( + country_id="us", + report_id=report["id"], + status="complete", + output=json.dumps({"ok": True}), + ) + completed_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report["id"],), + ).fetchone() + successful_run_id = completed_report["latest_successful_run_id"] + rerun = report_run_service.create_report_output_run( + report["id"], trigger_type="rerun" + ) + test_db.query( + """ + UPDATE report_outputs + SET active_run_id = ?, latest_successful_run_id = ? + WHERE id = ? + """, + (rerun["id"], successful_run_id, report["id"]), + ) + + client = create_test_client() + response = client.patch( + "/us/report", + json={ + "id": report["id"], + "status": "running", + }, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["status"] == "running" + assert payload["result"]["started_at"] is not None + assert payload["result"]["finished_at"] is None + + successful_run = test_db.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (successful_run_id,), + ).fetchone() + active_run = test_db.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (rerun["id"],), + ).fetchone() + assert successful_run["status"] == "complete" + assert successful_run["finished_at"] is not None + assert active_run["status"] == "running" + assert active_run["started_at"] is not None + assert active_run["finished_at"] is None + + +def test_patch_report_output_error_uses_active_rerun_timestamp_route_path(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_active_error_rerun", + population_type="household", + policy_id=49, + ) + report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + report_output_service.update_report_output( + country_id="us", + report_id=report["id"], + status="complete", + output=json.dumps({"ok": True}), + ) + completed_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report["id"],), + ).fetchone() + successful_run_id = completed_report["latest_successful_run_id"] + test_db.query( + """ + UPDATE report_output_runs + SET requested_at = ?, started_at = ?, finished_at = ? + WHERE id = ? + """, + ( + "2026-05-04 10:00:00", + "2026-05-04 10:01:00", + "2026-05-04 10:02:00", + successful_run_id, + ), + ) + rerun = report_run_service.create_report_output_run( + report["id"], trigger_type="rerun" + ) + test_db.query( + """ + UPDATE report_outputs + SET active_run_id = ?, latest_successful_run_id = ? + WHERE id = ? + """, + (rerun["id"], successful_run_id, report["id"]), + ) + + client = create_test_client() + response = client.patch( + "/us/report", + json={ + "id": report["id"], + "status": "error", + "error_message": "rerun failed", + }, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["status"] == "error" + assert payload["result"]["finished_at"] is not None + assert payload["result"]["finished_at"] != "2026-05-04T10:02:00Z" + + +def test_patch_report_output_complete_promotes_active_rerun_route_path(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_active_complete_rerun", + population_type="household", + policy_id=50, + ) + report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + report_output_service.update_report_output( + country_id="us", + report_id=report["id"], + status="complete", + output=json.dumps({"ok": True}), + ) + completed_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report["id"],), + ).fetchone() + successful_run_id = completed_report["latest_successful_run_id"] + rerun = report_run_service.create_report_output_run( + report["id"], trigger_type="rerun" + ) + test_db.query( + """ + UPDATE report_outputs + SET active_run_id = ?, latest_successful_run_id = ? + WHERE id = ? + """, + (rerun["id"], successful_run_id, report["id"]), + ) + + client = create_test_client() + response = client.patch( + "/us/report", + json={ + "id": report["id"], + "status": "complete", + "output": json.dumps({"ok": "rerun"}), + }, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["status"] == "complete" + assert payload["result"]["finished_at"] is not None + + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report["id"],), + ).fetchone() + assert stored_report["active_run_id"] is None + assert stored_report["latest_successful_run_id"] == rerun["id"] diff --git a/tests/unit/test_yearly_var_removal.py b/tests/unit/test_yearly_var_removal.py new file mode 100644 index 000000000..c5d81009c --- /dev/null +++ b/tests/unit/test_yearly_var_removal.py @@ -0,0 +1,107 @@ +import copy +from types import SimpleNamespace + +from policyengine_api.endpoints.household import add_yearly_variables + + +TEST_YEAR = "2023" + + +def fake_country_metadata(): + return { + "entities": { + "person": {"plural": "people"}, + "household": {"plural": "households"}, + }, + "variables": { + "age": { + "definitionPeriod": "year", + "entity": "person", + "name": "age", + "isInputVariable": True, + "defaultValue": 0, + }, + "employment_income": { + "definitionPeriod": "year", + "entity": "person", + "name": "employment_income", + "isInputVariable": True, + "defaultValue": 0, + }, + "household_net_income": { + "definitionPeriod": "year", + "entity": "household", + "name": "household_net_income", + "isInputVariable": False, + "defaultValue": None, + }, + "monthly_benefit": { + "definitionPeriod": "month", + "entity": "person", + "name": "monthly_benefit", + "isInputVariable": False, + "defaultValue": None, + }, + "person_id": { + "definitionPeriod": "eternity", + "entity": "person", + "name": "person_id", + "isInputVariable": True, + "defaultValue": "", + }, + "daily_value": { + "definitionPeriod": "day", + "entity": "person", + "name": "daily_value", + "isInputVariable": False, + "defaultValue": None, + }, + }, + } + + +def fake_countries(): + return SimpleNamespace( + get=lambda country_id: SimpleNamespace(metadata=fake_country_metadata()) + ) + + +def test_add_yearly_variables_fills_missing_year_month_and_eternity_values(): + household = { + "people": { + "you": { + "age": {TEST_YEAR: 40}, + "employment_income": {TEST_YEAR: 10_000}, + } + }, + "households": {"your household": {"members": ["you"]}}, + } + + result = add_yearly_variables( + copy.deepcopy(household), "test", countries=fake_countries() + ) + + assert result["people"]["you"]["employment_income"] == {TEST_YEAR: 10_000} + assert result["people"]["you"]["monthly_benefit"] == {TEST_YEAR: None} + assert result["people"]["you"]["person_id"] == {TEST_YEAR: ""} + assert "daily_value" not in result["people"]["you"] + assert result["households"]["your household"]["household_net_income"] == { + TEST_YEAR: None + } + + +def test_add_yearly_variables_ignores_entities_missing_from_household(): + household = { + "people": { + "you": { + "age": {TEST_YEAR: 40}, + } + } + } + + result = add_yearly_variables( + copy.deepcopy(household), "test", countries=fake_countries() + ) + + assert "households" not in result + assert result["people"]["you"]["employment_income"] == {TEST_YEAR: 0} diff --git a/uv.lock b/uv.lock index 8bb11c5e4..480fe985f 100644 --- a/uv.lock +++ b/uv.lock @@ -2622,7 +2622,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/a0/f3/eeea7dab690e46cd9 [[package]] name = "policyengine-api" -version = "3.40.7" +version = "3.40.11" source = { editable = "." } dependencies = [ { name = "anthropic" },