From c4b6d1e4623ffdcd682e7820441d9a96988637ed Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sun, 17 May 2026 15:29:12 -0400 Subject: [PATCH 1/2] Add policyengine.derivations: structured + narrated variable explanations For any variable on any Simulation, derive(simulation, variable, period) returns a Derivation with the pruned dependency tree and the same scalar the simulation would have returned. The Derivation can be: - rendered as indented text (trace_text()) - walked programmatically (TraceNode is a stable, frozen dataclass) - summarized by the deterministic top_level_contributions() helper - handed to narrate(derivation) for an optional LLM-generated narrative The narration path is in its own submodule with a lazy LiteLLM import, so importing policyengine.derivations costs nothing if you only want the deterministic structured tree. Motivation: policybench needed per-cell "how PolicyEngine derived this value" walkthroughs for its leaderboard's prediction-detail modal. The deterministic primitives are useful for any caller wanting "explain this result" surfaces (calculators, dashboards, papers), so they belong here rather than baked into a downstream consumer. 
--- changelog.d/derivations-narrative.added.md | 1 + src/policyengine/derivations/__init__.py | 35 ++++ src/policyengine/derivations/narrate.py | 124 ++++++++++++++ src/policyengine/derivations/trace.py | 189 +++++++++++++++++++++ tests/test_derivations.py | 171 +++++++++++++++++++ 5 files changed, 520 insertions(+) create mode 100644 changelog.d/derivations-narrative.added.md create mode 100644 src/policyengine/derivations/__init__.py create mode 100644 src/policyengine/derivations/narrate.py create mode 100644 src/policyengine/derivations/trace.py create mode 100644 tests/test_derivations.py diff --git a/changelog.d/derivations-narrative.added.md b/changelog.d/derivations-narrative.added.md new file mode 100644 index 00000000..bac28654 --- /dev/null +++ b/changelog.d/derivations-narrative.added.md @@ -0,0 +1 @@ +Add ``policyengine.derivations`` for per-variable computation explanations: ``derive(simulation, variable, period)`` returns a structured ``Derivation`` (with pruned trace and top-level contributions); ``narrate(derivation)`` optionally hands it to an LLM for a plain-prose walkthrough. diff --git a/src/policyengine/derivations/__init__.py b/src/policyengine/derivations/__init__.py new file mode 100644 index 00000000..11d3d748 --- /dev/null +++ b/src/policyengine/derivations/__init__.py @@ -0,0 +1,35 @@ +"""Derivations: structured + narrated explanations of one variable's value. + +A ``Derivation`` is the pruned, deterministic computation tree for a single +``(simulation, variable)`` pair. The tree is the same information OpenFisca +already records when ``simulation.trace`` is on, but presented as a stable +data class (independent of OpenFisca internals) so callers can: + +- print or serialize the structured tree (deterministic, free), +- pull out top-level contributions for charts or tables, and +- optionally hand the derivation to an LLM via :func:`narrate` for a plain-prose + walkthrough (the only step that requires a network call). 
+ +This module deliberately separates the *deterministic* part of the explanation +(everything in ``Derivation``) from the *narration* (an external LLM call). A +caller can use one without the other. +""" + +from .narrate import narrate, narrate_async +from .trace import ( + Derivation, + TraceNode, + derive, + is_zero_value, + top_level_contributions, +) + +__all__ = [ + "Derivation", + "TraceNode", + "derive", + "is_zero_value", + "narrate", + "narrate_async", + "top_level_contributions", +] diff --git a/src/policyengine/derivations/narrate.py b/src/policyengine/derivations/narrate.py new file mode 100644 index 00000000..d5059fbc --- /dev/null +++ b/src/policyengine/derivations/narrate.py @@ -0,0 +1,124 @@ +"""LLM narration of a structured :class:`Derivation`. + +This is the only piece of the derivations API that makes a network call. It +is kept in its own module so that callers who only want the deterministic +structured tree don't drag a network/LLM dependency into the import graph. + +LiteLLM is imported lazily inside the call so that ``import +policyengine.derivations`` doesn't require any LLM credentials to succeed. 
+""" + +from __future__ import annotations + +from typing import Any + +from .trace import Derivation + +DEFAULT_MODEL = "claude-sonnet-4-6" +DEFAULT_MAX_TOKENS = 500 +DEFAULT_TEMPERATURE = 0.0 + + +def _build_prompt( + derivation: Derivation, + *, + country: str | None, + household_summary: str | None, + extra_context: str | None, + trace_max_depth: int, +) -> str: + header_lines = [ + "You are summarizing how PolicyEngine derived a single variable's value " + "for one household.", + "", + f"VARIABLE: {derivation.variable}", + f"PERIOD: {derivation.period}", + f"REFERENCE VALUE: {derivation.value}", + ] + if country: + header_lines.insert(3, f"COUNTRY: {country.upper()}") + if household_summary: + header_lines.append(f"HOUSEHOLD: {household_summary}") + if extra_context: + header_lines.append("") + header_lines.append(extra_context) + + trace_text = derivation.trace_text(max_depth=trace_max_depth) + instructions = ( + "Write a 3-5 sentence narrative explaining how PolicyEngine arrived at " + "this value. Reference the most important intermediate quantities by " + "name and amount. Be concrete and quantitative. Plain prose, no " + "headers, no bullet lists." + ) + return ( + "\n".join(header_lines) + "\n\nPolicyEngine computation trace " + "(indented dependency tree, non-zero nodes only):\n```\n" + + trace_text + + "\n```\n\n" + + instructions + + "\n" + ) + + +def narrate( + derivation: Derivation, + *, + country: str | None = None, + household_summary: str | None = None, + extra_context: str | None = None, + model: str = DEFAULT_MODEL, + max_tokens: int = DEFAULT_MAX_TOKENS, + temperature: float = DEFAULT_TEMPERATURE, + trace_max_depth: int = 8, +) -> str: + """Synchronously ask an LLM to narrate this derivation. + + Imports LiteLLM lazily so that ``import policyengine.derivations`` has no + LLM dependency. Returns the model's response text. 
+ """ + import litellm # noqa: PLC0415 — lazy import keeps the base module light + + prompt = _build_prompt( + derivation, + country=country, + household_summary=household_summary, + extra_context=extra_context, + trace_max_depth=trace_max_depth, + ) + response = litellm.completion( + model=model, + messages=[{"role": "user", "content": prompt}], + temperature=temperature, + max_tokens=max_tokens, + ) + return response.choices[0].message.content.strip() + + +async def narrate_async( + derivation: Derivation, + *, + country: str | None = None, + household_summary: str | None = None, + extra_context: str | None = None, + model: str = DEFAULT_MODEL, + max_tokens: int = DEFAULT_MAX_TOKENS, + temperature: float = DEFAULT_TEMPERATURE, + trace_max_depth: int = 8, +) -> str: + """Async variant of :func:`narrate` — same interface, awaitable result.""" + import litellm # noqa: PLC0415 — lazy import keeps the base module light + + prompt = _build_prompt( + derivation, + country=country, + household_summary=household_summary, + extra_context=extra_context, + trace_max_depth=trace_max_depth, + ) + response: Any = await litellm.acompletion( + model=model, + messages=[{"role": "user", "content": prompt}], + temperature=temperature, + max_tokens=max_tokens, + ) + return response.choices[0].message.content.strip() diff --git a/src/policyengine/derivations/trace.py b/src/policyengine/derivations/trace.py new file mode 100644 index 00000000..a772ca23 --- /dev/null +++ b/src/policyengine/derivations/trace.py @@ -0,0 +1,189 @@ +"""Deterministic computation-trace extraction. + +Turns the live OpenFisca tracer output for a single variable into a stable +:class:`Derivation` data class. Everything here is pure and side-effect-free +once the tracer has captured the calculation. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Iterable + + +@dataclass(frozen=True) +class TraceNode: + """One node in the pruned computation tree. 
+ + Mirrors OpenFisca's per-variable trace entry: the variable name, the + scalar value it took for the requested household, and its immediate + dependencies (each themselves a :class:`TraceNode`). Booleans surface as + Python ``bool``; numeric values surface as ``float``. + """ + + name: str + value: Any + children: tuple[TraceNode, ...] = field(default_factory=tuple) + + def to_text(self, *, max_depth: int = 8, prune_zero: bool = True) -> str: + """Render the tree as an indented text block. + + Parameters + ---------- + max_depth: + Stop descending below this depth (root has depth 0). + prune_zero: + When True (the default), zero-valued subtrees below depth 1 are + omitted from the rendering. This is the format we feed to LLMs: + keep the top-level zero categories so the model knows they were + considered, but drop the cascading zero leaves under them. + """ + + lines: list[str] = [] + _render(self, depth=0, lines=lines, max_depth=max_depth, prune_zero=prune_zero) + return "\n".join(lines) + + +@dataclass(frozen=True) +class Derivation: + """A single variable's computation, captured as a structured tree. + + ``Derivation`` is the deterministic core of an explanation. It can be + rendered as text, walked programmatically for charts, or passed to + :func:`policyengine.derivations.narrate` for a prose summary. + """ + + variable: str + value: Any + trace: TraceNode + period: Any + + def trace_text(self, *, max_depth: int = 8, prune_zero: bool = True) -> str: + """Convenience wrapper around :meth:`TraceNode.to_text`.""" + return self.trace.to_text(max_depth=max_depth, prune_zero=prune_zero) + + def top_level_contributions(self) -> list[tuple[str, Any]]: + """``[(child_variable_name, value), ...]`` for the root's children. + + Useful when you want a deterministic structured breakdown next to the + prose narrative — e.g. "the answer is the sum of these named pieces". 
+ """ + return top_level_contributions(self) + + +def _scalar(value: Any) -> Any: + """Reduce a vectorized OpenFisca result to a single Python scalar.""" + if hasattr(value, "__len__") and len(value): + item = value[0] + else: + item = value + if hasattr(item, "item"): + item = item.item() + return item + + +def is_zero_value(value: Any) -> bool: + """True if ``value`` is the zero of its type (False, 0, 0.0). + + Exported because callers sometimes want to filter their own copies of the + tree without redefining what "zero" means. + """ + item = _scalar(value) + if isinstance(item, bool): + return item is False + if isinstance(item, (int, float)): + return item == 0 + return False + + +def _convert(node: Any) -> TraceNode: + """Convert an OpenFisca tracer node into our stable ``TraceNode`` shape.""" + return TraceNode( + name=node.name, + value=_scalar(node.value), + children=tuple(_convert(child) for child in node.children), + ) + + +def _render( + node: TraceNode, + *, + depth: int, + lines: list[str], + max_depth: int, + prune_zero: bool, +) -> None: + if depth > max_depth: + return + if prune_zero and depth > 1 and is_zero_value(node.value): + return + lines.append(" " * depth + node.name + " = " + _format_value(node.value)) + for child in node.children: + _render( + child, + depth=depth + 1, + lines=lines, + max_depth=max_depth, + prune_zero=prune_zero, + ) + + +def _format_value(value: Any) -> str: + if isinstance(value, bool): + return "True" if value else "False" + if isinstance(value, float): + return f"{value:.2f}".rstrip("0").rstrip(".") or "0" + return str(value) + + +def _find_root(roots: Iterable[Any], target: str) -> Any | None: + """Depth-first search the tracer roots for a node named ``target``.""" + for root in roots: + if root.name == target: + return root + match = _find_root(root.children, target) + if match is not None: + return match + return None + + +def derive(simulation: Any, variable: str, period: Any) -> Derivation: + """Compute 
``variable`` on ``simulation`` and return a structured derivation. + + The caller is responsible for owning the ``Simulation`` and any reform on + it. ``derive`` turns the tracer on (if not already on), clears any prior + trees so the captured tree is exactly the one we asked for, runs the + calculation, and converts the resulting tree to a stable ``TraceNode``. + """ + + simulation.trace = True + if hasattr(simulation, "tracer") and hasattr(simulation.tracer, "trees"): + simulation.tracer.trees.clear() + simulation.calculate(variable, period) + + if not hasattr(simulation, "tracer") or not simulation.tracer.trees: + raise RuntimeError( + f"No trace recorded after calculating {variable!r}. " + "Ensure the simulation backend supports tracing." + ) + root = _find_root(simulation.tracer.trees, variable) + if root is None: + raise RuntimeError( + f"Tracer did not produce a root for {variable!r}. " + "This usually means the variable was already cached." + ) + return Derivation( + variable=variable, + value=_scalar(root.value), + trace=_convert(root), + period=period, + ) + + +def top_level_contributions(derivation: Derivation) -> list[tuple[str, Any]]: + """Return ``[(name, value), ...]`` for the immediate dependencies of the root. + + Children appear in the order OpenFisca recorded them. Use this when you + want a deterministic structured breakdown alongside the prose narrative. + """ + return [(child.name, child.value) for child in derivation.trace.children] diff --git a/tests/test_derivations.py b/tests/test_derivations.py new file mode 100644 index 00000000..4b3d6e96 --- /dev/null +++ b/tests/test_derivations.py @@ -0,0 +1,171 @@ +"""Tests for ``policyengine.derivations``. + +The deterministic part is exercised against a small US household; the +narration path is exercised with a fake LiteLLM client so the test stays +hermetic. 
+""" + +from __future__ import annotations + +import asyncio +import types + +import pytest + +from policyengine.derivations import ( + Derivation, + TraceNode, + derive, + is_zero_value, + narrate, + narrate_async, + top_level_contributions, +) + + +def _us_simulation() -> object: + """Return a US Simulation set up with a single working adult. + + Skipped if ``policyengine_us`` is not installed in this environment. + """ + pytest.importorskip("policyengine_us") + from policyengine_us import Simulation # type: ignore + + return Simulation( + situation={ + "people": { + "head": { + "age": {2026: 35}, + "employment_income": {2026: 45000}, + } + }, + "tax_units": {"tu": {"members": ["head"]}}, + "households": {"h": {"members": ["head"], "state_code": {2026: "TX"}}}, + } + ) + + +def test_is_zero_value_handles_scalars_and_arrays(): + assert is_zero_value(0) + assert is_zero_value(0.0) + assert is_zero_value(False) + assert not is_zero_value(1) + assert not is_zero_value(True) + assert not is_zero_value(0.01) + + +def test_derive_returns_structured_tree_for_us_income_tax(): + sim = _us_simulation() + derivation = derive(sim, "income_tax_before_credits", 2026) + + assert isinstance(derivation, Derivation) + assert derivation.variable == "income_tax_before_credits" + assert isinstance(derivation.value, float) + assert derivation.value > 0 + assert isinstance(derivation.trace, TraceNode) + assert derivation.trace.name == "income_tax_before_credits" + # Tree should include the canonical AGI/taxable-income path. We don't + # bind to specific numerics so the test stays stable across model + # parameter updates. 
+ text = derivation.trace_text(max_depth=4) + assert "taxable_income" in text + assert "adjusted_gross_income" in text + + +def test_top_level_contributions_lists_immediate_children(): + sim = _us_simulation() + derivation = derive(sim, "income_tax_before_credits", 2026) + + contributions = top_level_contributions(derivation) + assert contributions, "expected at least one top-level dependency" + names = {name for name, _ in contributions} + # ``income_tax_before_credits`` is the sum of regular tax, AMT, capital + # gains tax, etc. We check for the rates path which is always present. + assert "income_tax_main_rates" in names + + +def test_trace_text_drops_zero_subtrees_below_depth_one(): + leaf_zero = TraceNode(name="dependent_zero_thing", value=0.0) + leaf_nonzero = TraceNode(name="dependent_real_thing", value=5.0) + intermediate = TraceNode( + name="payable_component", + value=0.0, + children=(leaf_zero, leaf_nonzero), + ) + root = TraceNode(name="root_var", value=5.0, children=(intermediate,)) + + text = root.to_text() + # Top-level zero is preserved (depth 1) so the caller sees the category. + assert "payable_component = 0" in text + # Zero leaves under it are dropped. + assert "dependent_zero_thing" not in text + # Non-zero leaves are kept. 
+ assert "dependent_real_thing = 5" in text + + +def test_narrate_passes_trace_to_litellm(monkeypatch): + captured: dict[str, object] = {} + + def fake_completion(**kwargs): + captured["kwargs"] = kwargs + message = types.SimpleNamespace(content="Stub narrative.") + choice = types.SimpleNamespace(message=message) + return types.SimpleNamespace(choices=[choice]) + + monkeypatch.setitem( + __import__("sys").modules, + "litellm", + types.SimpleNamespace(completion=fake_completion), + ) + + derivation = Derivation( + variable="federal_income_tax", + value=4454.8, + period=2026, + trace=TraceNode( + name="federal_income_tax", + value=4454.8, + children=( + TraceNode(name="adjusted_gross_income", value=81820.85), + TraceNode(name="standard_deduction", value=16100), + ), + ), + ) + + narrative = narrate( + derivation, + country="us", + household_summary="TX, joint, 2 adults, ~$85k income", + ) + assert narrative == "Stub narrative." + prompt = captured["kwargs"]["messages"][0]["content"] + assert "federal_income_tax" in prompt + assert "adjusted_gross_income = 81820.85" in prompt + assert "TX, joint, 2 adults, ~$85k income" in prompt + assert "COUNTRY: US" in prompt + + +def test_narrate_async_uses_litellm_acompletion(monkeypatch): + captured: dict[str, object] = {} + + async def fake_acompletion(**kwargs): + captured["kwargs"] = kwargs + message = types.SimpleNamespace(content="Async stub.") + choice = types.SimpleNamespace(message=message) + return types.SimpleNamespace(choices=[choice]) + + monkeypatch.setitem( + __import__("sys").modules, + "litellm", + types.SimpleNamespace(acompletion=fake_acompletion), + ) + + derivation = Derivation( + variable="x", + value=1.0, + period=2026, + trace=TraceNode(name="x", value=1.0), + ) + result = asyncio.run(narrate_async(derivation)) + assert result == "Async stub." 
+ assert captured["kwargs"]["model"] == "claude-sonnet-4-6" From aac7c69e50e7e2218fb3725221efd532e7d19a68 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sun, 17 May 2026 16:43:38 -0400 Subject: [PATCH 2/2] derivations: keep per-entity values in the trace The previous implementation took the first element of OpenFisca's vectorised result for every node, which silently dropped every other entity's contribution. For a joint household with $45k self-employment income (head) and $40k wages (spouse), the narrative would report "the household's only income is $45,000 of self-employment income" because irs_gross_income's [45000, 40000] array was truncated to its first element, 45000. Replace _scalar with _capture, which: - collapses length-1 arrays to a scalar (the common case for tax-unit / household variables), - preserves multi-entity arrays as tuples (numeric or boolean). Update _format_value to render numeric tuples as ``sum (per entity: a, b, ...)`` so summarisers see both the total and the per-person decomposition; boolean tuples render as ``[True, False]``. Update is_zero_value to recurse into tuples (every entry must be zero). Add tests covering the multi-entity render and zero check. 
--- src/policyengine/derivations/trace.py | 68 +++++++++++++++++++-------- tests/test_derivations.py | 22 +++++++++ 2 files changed, 71 insertions(+), 19 deletions(-) diff --git a/src/policyengine/derivations/trace.py b/src/policyengine/derivations/trace.py index a772ca23..220b710e 100644 --- a/src/policyengine/derivations/trace.py +++ b/src/policyengine/derivations/trace.py @@ -71,28 +71,45 @@ def top_level_contributions(self) -> list[tuple[str, Any]]: return top_level_contributions(self) -def _scalar(value: Any) -> Any: - """Reduce a vectorized OpenFisca result to a single Python scalar.""" - if hasattr(value, "__len__") and len(value): - item = value[0] - else: - item = value - if hasattr(item, "item"): - item = item.item() - return item +def _to_python(value: Any) -> Any: + """Convert a numpy scalar to a native Python scalar; pass through tuples.""" + if hasattr(value, "item"): + return value.item() + return value + + +def _capture(value: Any) -> Any: + """Capture an OpenFisca trace value as a Python scalar or tuple. + + Per-person variables come through as numpy arrays of length N (one per + person in the household); tax-unit / household variables come through as + length-1 arrays. Length-1 arrays collapse to a scalar so most renderings + stay terse; multi-entity arrays are preserved as tuples so that + summarising "$45,000 SE income + $40,000 wages" doesn't silently drop + the spouse's row. + """ + if hasattr(value, "__len__"): + if len(value) == 0: + return None + if len(value) == 1: + return _to_python(value[0]) + return tuple(_to_python(item) for item in value) + return _to_python(value) def is_zero_value(value: Any) -> bool: - """True if ``value`` is the zero of its type (False, 0, 0.0). + """True iff ``value`` is the zero of its type across every entity. - Exported because callers sometimes want to filter their own copies of the - tree without redefining what "zero" means. + For multi-entity (tuple) values, every entry must be falsy/zero. 
Exported + because callers sometimes want to filter their own copies of the tree + without redefining what "zero" means. """ - item = _scalar(value) - if isinstance(item, bool): - return item is False - if isinstance(item, (int, float)): - return item == 0 + if isinstance(value, tuple): + return all(is_zero_value(item) for item in value) + if isinstance(value, bool): + return value is False + if isinstance(value, (int, float)): + return value == 0 return False @@ -100,7 +117,7 @@ def _convert(node: Any) -> TraceNode: """Convert an OpenFisca tracer node into our stable ``TraceNode`` shape.""" return TraceNode( name=node.name, - value=_scalar(node.value), + value=_capture(node.value), children=tuple(_convert(child) for child in node.children), ) @@ -129,6 +146,19 @@ def _render( def _format_value(value: Any) -> str: + if isinstance(value, tuple): + formatted = [_format_value(item) for item in value] + if all( + isinstance(item, (int, float)) and not isinstance(item, bool) + for item in value + ): + total = sum(value) + return f"{_format_scalar(total)} (per entity: {', '.join(formatted)})" + return "[" + ", ".join(formatted) + "]" + return _format_scalar(value) + + +def _format_scalar(value: Any) -> str: if isinstance(value, bool): return "True" if value else "False" if isinstance(value, float): @@ -174,7 +204,7 @@ def derive(simulation: Any, variable: str, period: Any) -> Derivation: ) return Derivation( variable=variable, - value=_scalar(root.value), + value=_capture(root.value), trace=_convert(root), period=period, ) diff --git a/tests/test_derivations.py b/tests/test_derivations.py index 4b3d6e96..453d92b4 100644 --- a/tests/test_derivations.py +++ b/tests/test_derivations.py @@ -103,6 +103,28 @@ def test_trace_text_drops_zero_subtrees_below_depth_one(): assert "dependent_real_thing = 5" in text +def test_trace_text_shows_per_entity_values_with_sum(): + # Multi-person households: per-person values come through as tuples and + # render as "sum (per entity: a, b)" 
so summarisers don't silently drop + # the spouse's contribution. + root = TraceNode( + name="irs_gross_income", + value=(45000.0, 40000.0), + ) + text = root.to_text() + assert "85000" in text + assert "45000" in text + assert "40000" in text + assert "per entity" in text + + +def test_is_zero_value_handles_multi_entity_tuples(): + assert is_zero_value((0, 0, 0)) + assert is_zero_value((False, False)) + assert not is_zero_value((0, 1)) + assert not is_zero_value((False, True)) + + def test_narrate_passes_trace_to_litellm(monkeypatch): captured: dict[str, object] = {}